Skip to content

Commit

Permalink
Added tests to contact exchange protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
manforowicz committed Mar 24, 2024
1 parent 5290335 commit 8579675
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 68 deletions.
19 changes: 16 additions & 3 deletions gday/src/dialog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result<Vec<Option<u64>>>
let mut new_size = 0;
let mut total_size = 0;

println!("Your mate wants to send you {} files:", files.len().bold());
println!(
"{} {} {}",
"Your mate wants to send you".bold(),
files.len().bold(),
"files:".bold()
);
for file in files {
// print file metadata
print!("{} ({})", file.short_path.display(), HumanBytes(file.len));
Expand All @@ -32,8 +37,15 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result<Vec<Option<u64>>>

// an interrupted download exists
} else if let Some(local_len) = interrupted_exists(file)? {
print!(" {}", "INTERRUPTED".bold());
new_size += file.len - local_len;
let remaining_len = file.len - local_len;

print!(
" {} {} {}",
"PARTIALLY DOWNLOADED.".bold(),
HumanBytes(remaining_len).bold(),
"REMAINING".bold()
);
new_size += remaining_len;
new_files.push(Some(local_len));

// this file does not exist
Expand All @@ -59,6 +71,7 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result<Vec<Option<u64>>>
println!("1. Download all files.");
println!("2. Download only files with new path or size. Resume any interrupted downloads.");
println!("3. Cancel.");
println!("Note: Gday will create new files, instead of overwriting existing files.");
print!("{} ", "Choose an option (1, 2, or 3):".bold());
std::io::stdout().flush()?;

Expand Down
10 changes: 9 additions & 1 deletion gday/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
// get peer's contact
let peer_contact = contact_sharer.get_peer_contact()?;
info!("Your mate's contact is:\n{peer_contact}");
info!("Trying TCP hole punching.");

// connect to the peer
let (stream, shared_key) = gday_hole_punch::try_connect_to_peer(
Expand All @@ -145,9 +146,13 @@ fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
// offer these files to the peer
serialize_into(FileOfferMsg { files }, &mut stream)?;

info!("Waiting for peer to respond to file offer.");

// receive file offer from peer
let response: FileResponseMsg = deserialize_from(&mut stream, &mut Vec::new())?;

info!("Starting file send.");

transfer::send_files(&mut stream, &local_files, &response.accepted)?;
}

Expand All @@ -162,6 +167,7 @@ fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
let peer_contact = contact_sharer.get_peer_contact()?;

info!("Your mate's contact is:\n{peer_contact}");
info!("Trying TCP hole punching.");

let (stream, shared_key) = gday_hole_punch::try_connect_to_peer(
my_contact.private,
Expand All @@ -187,11 +193,13 @@ fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
&mut stream,
)?;

info!("Starting file download.");

transfer::receive_files(&mut stream, &offer.files, &accepted)?;
}
}

println!("{}", "Success!".bold().green());
println!("{}", "Done!".bold().green());

Ok(())
}
15 changes: 9 additions & 6 deletions gday/src/transfer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use gday_file_offer_protocol::{FileMeta, FileMetaLocal};
use crate::TMP_DOWNLOAD_PREFIX;
use gday_encryption::EncryptedStream;
use gday_file_offer_protocol::{FileMeta, FileMetaLocal};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};

use crate::TMP_DOWNLOAD_PREFIX;
const FILE_BUFFER_SIZE: usize = 1_000_000;

/// Wrap a [`TcpStream`] in a [`gday_encryption::EncryptedStream`].
pub fn encrypt_connection<T: Read + Write>(
Expand Down Expand Up @@ -58,7 +59,8 @@ pub fn send_files(
progress.set_message(format!("sending {msg}"));

// copy the file into the writer
let mut file = File::open(&meta.local_path)?;
let mut file =
BufReader::with_capacity(FILE_BUFFER_SIZE, File::open(&meta.local_path)?);
// TODO: maybe check if file length is correct?

file.seek(SeekFrom::Start(start))?;
Expand Down Expand Up @@ -118,7 +120,7 @@ pub fn receive_files(
let tmp_path = meta.get_prefixed_save_path(TMP_DOWNLOAD_PREFIX.into())?;

// create the temporary download file
let mut file = File::create(&tmp_path)?;
let mut file = BufWriter::with_capacity(FILE_BUFFER_SIZE, File::create(&tmp_path)?);

// only take the length of the file from the reader
let mut reader = reader.take(meta.len);
Expand All @@ -145,7 +147,8 @@ pub fn receive_files(

// open the partially downloaded file in append mode
let tmp_path = meta.get_prefixed_save_path(TMP_DOWNLOAD_PREFIX.into())?;
let mut file = OpenOptions::new().append(true).open(&tmp_path).unwrap();
let file = OpenOptions::new().append(true).open(&tmp_path)?;
let mut file = BufWriter::with_capacity(FILE_BUFFER_SIZE, file);

// only take the length of the remaining part of the file from the reader
let mut reader = reader.take(meta.len - start);
Expand Down
48 changes: 35 additions & 13 deletions gday_contact_exchange_protocol/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
//! This protocol lets two users exchange their public and (optionally) private socket addresses via a server.
//! On it's own, this crate doesn't do anything other than define a shared protocol.
//! This is done with the following process:
//! On it's own, this crate doesn't do anything other than define a shared protocol, and functions to
//! send and receive messages of this protocol.
//!
//! # Process
//!
//! Using this protocol goes something like this:
//!
//! 1. `peer A` connects to a server via the internet
//! and requests a new room using [`ClientMsg::CreateRoom`].
//! and requests a new room with `room_code` using [`ClientMsg::CreateRoom`].
//!
//! 2. The server replies to `peer A` with a random unused room code via [`ServerMsg::RoomCreated`].
//! 2. The server replies to `peer A` with [`ServerMsg::RoomCreated`] or [`ServerMsg::ErrorRoomTaken`]
//! depending on if this `room_code` is in use.
//!
//! 3. `peer A` externally tells `peer B` this code (by phone call, text message, carrier pigeon, etc.).
//! 3. `peer A` externally tells `peer B` their `room_code` (by phone call, text message, carrier pigeon, etc.).
//!
//! 4. Both peers send this room code and optionally their local/private socket addresses to the server
//! 4. Both peers send this `room_code` and optionally their local/private socket addresses to the server
//! via [`ClientMsg::SendAddr`] messages. The server determines their public addresses from the internet connections.
//! The server replies with [`ServerMsg::ReceivedAddr`] after each of these messages.
//!
//! 5. Both peers send [`ClientMsg::DoneSending`] once they are ready to receive the contact info of each other.
//!
//! 6. The server sends each peer their contact information via [`ServerMsg::ClientContact`]
//! 6. The server immediately replies to [`ClientMsg::DoneSending`]
//! with [`ServerMsg::ClientContact`] which contains the [`FullContact`] of this peer.
//!
//! 7. Once both peers are ready, the server sends each peer the public and private socket addresses
//! of the other peer via [`ServerMsg::PeerContact`].
//! 7. Once both peers are ready, the server sends (on the same stream where [`ClientMsg::DoneSending`] came from)
//! each peer [`ServerMsg::PeerContact`] which contains the [`FullContact`] of the other peer..
//!
//! 8. On their own, the peers use this info to connect directly to each other by using
//! [hole punching](https://en.wikipedia.org/wiki/Hole_punching_(networking)).
//!
//! #
#![forbid(unsafe_code)]
#![warn(clippy::all)]

Expand Down Expand Up @@ -179,11 +187,16 @@ impl std::fmt::Display for FullContact {
// The max size of a `ServerMsg` is
// 5 + 116 = 121 bytes
//
// This means the length can be represented with a single u8.
pub const MAX_MSG_SIZE: usize = 121;
// This means the length of a message can be represented with a single u8.

/// The maximum length of a serialized message.
/// Constrained by the max value of the message length byte header.
pub const MAX_MSG_SIZE: usize = u8::MAX as usize;

/// Write `msg` to `writer` using [`postcard`].
/// Prefixes the message with a byte that holds its length.
pub fn serialize_into(msg: impl Serialize, writer: &mut impl Write) -> Result<(), Error> {
let mut buf = [0_u8; 255];
let mut buf = [0_u8; MAX_MSG_SIZE];
let len = to_slice(&msg, &mut buf[1..])?.len();
let len_byte = u8::try_from(len).expect("Unreachable: Message always shorter than u8::MAX");
buf[0] = len_byte;
Expand All @@ -192,6 +205,8 @@ pub fn serialize_into(msg: impl Serialize, writer: &mut impl Write) -> Result<()
Ok(())
}

/// Write `msg` to `writer` using [`postcard`].
/// Prefixes the message with a byte that holds its length.
pub fn deserialize_from<'a, T: Deserialize<'a>>(
reader: &mut impl Read,
buf: &'a mut [u8],
Expand All @@ -203,19 +218,23 @@ pub fn deserialize_from<'a, T: Deserialize<'a>>(
Ok(from_bytes(&buf[0..len])?)
}

/// Asynchronously write `msg` to `writer` using [`postcard`].
/// Prefixes the message with a byte that holds its length.
pub async fn serialize_into_async(
msg: impl Serialize,
writer: &mut (impl AsyncWrite + Unpin),
) -> Result<(), Error> {
let mut buf = [0_u8; MAX_MSG_SIZE];
let len = to_slice(&msg, &mut buf[1..])?.len();
let len_byte = u8::try_from(len).expect("Unreachable: Message always shorter than u8::MAX");
let len_byte = u8::try_from(len)?;
buf[0] = len_byte;
writer.write_all(&buf[0..len + 1]).await?;
writer.flush().await?;
Ok(())
}

/// Asynchronously write `msg` to `writer` using [`postcard`].
/// Prefixes the message with a byte that holds its length.
pub async fn deserialize_from_async<'a, T: Deserialize<'a>>(
reader: &mut (impl AsyncRead + Unpin),
buf: &'a mut [u8],
Expand All @@ -238,4 +257,7 @@ pub enum Error {
/// IO Error sending or receiving a message
#[error("IO Error: {0}")]
IO(#[from] std::io::Error),

#[error("Message longer than max of 256 bytes.")]
MsgTooLong(#[from] std::num::TryFromIntError),
}
129 changes: 91 additions & 38 deletions gday_contact_exchange_protocol/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,108 @@
#![cfg(test)]
use crate::{ClientMsg, Contact, FullContact, ServerMsg};

// TODO: REWRITE
/*
#[tokio::test]
async fn messenger_send_1() {
let (mut stream_a, mut stream_b) = tokio::io::duplex(1000);
let mut messenger_a = AsyncMessenger::new(&mut stream_a);
let mut messenger_b = AsyncMessenger::new(&mut stream_b);
/// Test serializing and deserializing messages.
#[test]
fn sending_messages() {
let mut bytes = std::collections::VecDeque::new();

let sent = ServerMsg::ErrorNoSuchRoomID;
for msg in get_client_msg_examples() {
crate::serialize_into(msg, &mut bytes).unwrap();
}

messenger_a.send(&sent).await.unwrap();
for msg in get_client_msg_examples() {
let mut buf = [0; crate::MAX_MSG_SIZE];
let deserialized_msg: ClientMsg = crate::deserialize_from(&mut bytes, &mut buf).unwrap();
assert_eq!(msg, deserialized_msg);
}

let received: ServerMsg = messenger_b.receive().await.unwrap();
for msg in get_server_msg_examples() {
crate::serialize_into(msg, &mut bytes).unwrap();
}

assert_eq!(sent, received);
for msg in get_server_msg_examples() {
let mut buf = [0; crate::MAX_MSG_SIZE];
let deserialized_msg: ServerMsg = crate::deserialize_from(&mut bytes, &mut buf).unwrap();
assert_eq!(msg, deserialized_msg);
}
}

/// Test serializing and deserializing messages asynchronously.
#[tokio::test]
async fn messenger_send_2() {
let (mut stream_a, mut stream_b) = tokio::io::duplex(1000);
let mut messenger_a = AsyncMessenger::new(&mut stream_a);
let mut messenger_b = AsyncMessenger::new(&mut stream_b);
async fn sending_messages_async() {
let (mut writer, mut reader) = tokio::io::duplex(1000);

let socket = SocketAddr::V6(SocketAddrV6::new(578674694309532.into(), 1456, 0, 0));
for msg in get_client_msg_examples() {
crate::serialize_into_async(msg, &mut writer).await.unwrap();
}

let sent = ClientMsg::SendAddr {
room_code: 65721,
is_creator: false,
private_addr: Some(socket),
};
for msg in get_client_msg_examples() {
let mut buf = [0; crate::MAX_MSG_SIZE];
let deserialized_msg: ClientMsg = crate::deserialize_from_async(&mut reader, &mut buf)
.await
.unwrap();
assert_eq!(msg, deserialized_msg);
}

messenger_a.send(&sent).await.unwrap();
for msg in get_server_msg_examples() {
crate::serialize_into_async(msg, &mut writer).await.unwrap();
}

let received: ClientMsg = messenger_b.receive().await.unwrap();
for msg in get_server_msg_examples() {
let mut buf = [0; crate::MAX_MSG_SIZE];
let deserialized_msg: ServerMsg = crate::deserialize_from_async(&mut reader, &mut buf)
.await
.unwrap();
assert_eq!(msg, deserialized_msg);
}
}

assert_eq!(sent, received);
/// Get a [`Vec`] of example [`ClientMsg`]s.
fn get_client_msg_examples() -> Vec<ClientMsg> {
vec![
ClientMsg::CreateRoom { room_code: 452932 },
ClientMsg::SendAddr {
room_code: 2345,
is_creator: true,
private_addr: Some("31.31.65.31:324".parse().unwrap()),
},
ClientMsg::DoneSending {
room_code: 24325423,
is_creator: false,
},
]
}

#[tokio::test]
async fn messenger_invalid_data() {
let (mut stream_a, mut stream_b) = tokio::io::duplex(1000);
// gibberish data
stream_a
.write_all(&[0, 12, 53, 24, 85, 52, 24, 123, 32, 52, 52, 52, 13, 35])
.await
.unwrap();
let mut messenger_b = AsyncMessenger::new(&mut stream_b);
let result: Result<ServerMsg, Error> = messenger_b.receive().await;
assert!(result.is_err());
/// Get a [`Vec`] of example [`ServerMsg`]s.
fn get_server_msg_examples() -> Vec<ServerMsg> {
vec![
ServerMsg::RoomCreated,
ServerMsg::ReceivedAddr,
ServerMsg::ClientContact(FullContact {
private: Contact {
v4: Some("31.31.65.31:324".parse().unwrap()),
v6: Some("[2001:db8::1]:8080".parse().unwrap()),
},
public: Contact {
v4: Some("31.31.65.31:324".parse().unwrap()),
v6: Some("[2001:db8::1]:8080".parse().unwrap()),
},
}),
ServerMsg::PeerContact(FullContact {
private: Contact {
v4: Some("31.31.65.31:324".parse().unwrap()),
v6: Some("[2001:db8::1]:8080".parse().unwrap()),
},
public: Contact {
v4: Some("31.31.65.31:324".parse().unwrap()),
v6: Some("[2001:db8::1]:8080".parse().unwrap()),
},
}),
ServerMsg::ErrorRoomTaken,
ServerMsg::ErrorPeerTimedOut,
ServerMsg::ErrorNoSuchRoomID,
ServerMsg::ErrorTooManyRequests,
ServerMsg::SyntaxError,
ServerMsg::ConnectionError,
]
}
*/
2 changes: 1 addition & 1 deletion gday_encryption/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
};

#[test]
fn test_all() {
fn test_small_messages() {
let nonce = [5; 7];
let key = [5; 32];

Expand Down
2 changes: 1 addition & 1 deletion gday_hole_punch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ rustls = "0.23.3"
socket2 = "0.5.6"
spake2 = { version = "0.4.0", features = ["std"] }
thiserror = "1.0.58"
tokio = { version = "1.36.0", features = ["net"] }
tokio = { version = "1.36.0", features = ["net", "rt", "time"] }
webpki-roots = "0.26.1"
Loading

0 comments on commit 8579675

Please sign in to comment.