From 3daaded034a72d6f31423f5b82d2a1ade8b4c952 Mon Sep 17 00:00:00 2001 From: Justus Tumacder Date: Tue, 28 Feb 2023 11:03:22 +0800 Subject: [PATCH] Add the ability to check claimed nameplates - Adds `list_nameplates` which returns a list of currently claimed nameplates - Adds a new argument (`expect_claimed_nameplate`) to `connect_with_code`. When true, the function will return an error if the nameplate is not claimed --- cli/src/main.rs | 5 +++-- src/core.rs | 12 ++++++++++++ src/core/rendezvous.rs | 21 +++++++++++++++++++++ src/core/test.rs | 38 ++++++++++++++++++++++++++++++++------ 4 files changed, 68 insertions(+), 8 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index e7d07b1a..476a037a 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -656,7 +656,7 @@ async fn parse_and_connect( )?; } let (server_welcome, wormhole) = - magic_wormhole::Wormhole::connect_with_code(app_config, code).await?; + magic_wormhole::Wormhole::connect_with_code(app_config, code, false).await?; print_welcome(term, &server_welcome)?; (wormhole, server_welcome.code) }, @@ -860,7 +860,8 @@ async fn send_many( } let (_server_welcome, wormhole) = - magic_wormhole::Wormhole::connect_with_code(transfer::APP_CONFIG, code.clone()).await?; + magic_wormhole::Wormhole::connect_with_code(transfer::APP_CONFIG, code.clone(), false) + .await?; send_in_background( relay_hints.clone(), Arc::clone(&file_path), diff --git a/src/core.rs b/src/core.rs index 27bf0d3c..282ad58a 100644 --- a/src/core.rs +++ b/src/core.rs @@ -42,6 +42,8 @@ pub enum WormholeError { PakeFailed, #[error("Cannot decrypt a received message")] Crypto, + #[error("Nameplate is unclaimed: {}", _0)] + UnclaimedNameplate(Nameplate), } impl WormholeError { @@ -165,6 +167,7 @@ impl Wormhole { pub async fn connect_with_code( config: AppConfig, code: Code, + expect_claimed_nameplate: bool, ) -> Result<(WormholeWelcome, Self), WormholeError> { let AppConfig { id: appid, @@ -174,6 +177,15 @@ impl Wormhole { let (mut server, welcome) = RendezvousServer::connect(&appid, &rendezvous_url).await?; let nameplate = code.nameplate(); + + if expect_claimed_nameplate { + let nameplate = code.nameplate(); + let nameplates = server.list_nameplates().await?; + if !nameplates.contains(&nameplate) { + return Err(WormholeError::UnclaimedNameplate(nameplate)); + } + } + let mailbox = server.claim_open(nameplate).await?; log::debug!("Connected to mailbox {}", mailbox); diff --git a/src/core/rendezvous.rs b/src/core/rendezvous.rs index b013877b..a107297c 100644 --- a/src/core/rendezvous.rs +++ b/src/core/rendezvous.rs @@ -74,6 +74,10 @@ impl RendezvousError { type MessageQueue = VecDeque; +#[derive(Clone, Debug, derive_more::Display)] +#[display(fmt = "{:?}", _0)] +struct NameplateList(Vec); + #[cfg(not(target_family = "wasm"))] struct WsConnection { connection: async_tungstenite::WebSocketStream, @@ -174,6 +178,9 @@ impl WsConnection { Some(InboundMessage::Error { error, orig: _ }) => { break Err(RendezvousError::Server(error.into())); }, + Some(InboundMessage::Nameplates { nameplates }) => { + break Ok(RendezvousReply::Nameplates(NameplateList(nameplates))) + }, Some(other) => { break Err(RendezvousError::protocol(format!( "Got unexpected message type from server '{}'", @@ -273,6 +280,7 @@ enum RendezvousReply { Released, Claimed(Mailbox), Closed, + Nameplates(NameplateList), } #[derive(Clone, Debug, derive_more::Display)] @@ -528,6 +536,19 @@ impl RendezvousServer { .is_some() } + /** + * Gets the list of currently claimed nameplates. + * This can be called at any time. + */ + pub async fn list_nameplates(&mut self) -> Result, RendezvousError> { + self.send_message(&OutboundMessage::List).await?; + let nameplate_reply = self.receive_reply().await?; + match nameplate_reply { + RendezvousReply::Nameplates(x) => Ok(x.0), + other => Err(RendezvousError::invalid_message("nameplates", other)), + } + } + pub async fn release_nameplate(&mut self) -> Result<(), RendezvousError> { let nameplate = &mut self .state diff --git a/src/core/test.rs b/src/core/test.rs index b060a3dc..5e9d99ed 100644 --- a/src/core/test.rs +++ b/src/core/test.rs @@ -75,7 +75,8 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { let code = code_rx.await?; log::info!("Got code over local: {}", &code); let (welcome, wormhole) = - Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?; + Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true) + .await?; if let Some(welcome) = &welcome.welcome { log::info!("Got welcome: {}", welcome); } @@ -150,7 +151,8 @@ pub async fn test_4096_file_rust2rust() -> eyre::Result<()> { let code = code_rx.await?; log::info!("Got code over local: {}", &code); let (welcome, wormhole) = - Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?; + Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true) + .await?; if let Some(welcome) = &welcome.welcome { log::info!("Got welcome: {}", welcome); } @@ -223,7 +225,8 @@ pub async fn test_empty_file_rust2rust() -> eyre::Result<()> { let code = code_rx.await?; log::info!("Got code over local: {}", &code); let (welcome, wormhole) = - Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?; + Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true) + .await?; if let Some(welcome) = &welcome.welcome { log::info!("Got welcome: {}", welcome); } @@ -302,6 +305,7 @@ pub async fn test_send_many() -> eyre::Result<()> { let (_welcome, wormhole) = Wormhole::connect_with_code( transfer::APP_CONFIG.id(TEST_APPID), sender_code.clone(), + false, ) .await?; senders.push(async_std::task::spawn(async move { @@ -329,7 +333,8 @@ pub async fn test_send_many() -> eyre::Result<()> { for i in 0..5usize { log::info!("Receiving file #{}", i); let (_welcome, wormhole) = - Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code.clone()).await?; + Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code.clone(), true) + .await?; log::info!("Got key: {}", &wormhole.key); let req = crate::transfer::request_file( wormhole, @@ -389,6 +394,7 @@ pub async fn test_wrong_code() -> eyre::Result<()> { APP_CONFIG, /* Making a wrong code here by appending bullshit */ Code::new(&nameplate, "foo-bar"), + true, ) .await; @@ -411,9 +417,9 @@ pub async fn test_crowded() -> eyre::Result<()> { let (welcome, connector1) = Wormhole::connect_without_code(APP_CONFIG, 2).await?; log::info!("This test's code is: {}", &welcome.code); - let connector2 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone()); + let connector2 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone(), true); - let connector3 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone()); + let connector3 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone(), true); match futures::try_join!(connector1, connector2, connector3).unwrap_err() { magic_wormhole::WormholeError::ServerError( @@ -427,6 +433,26 @@ pub async fn test_crowded() -> eyre::Result<()> { Ok(()) } +#[async_std::test] +pub async fn test_connect_with_code_expecting_nameplate() -> eyre::Result<()> { + // the max nameplate number is 999, so this will not impact a real nameplate + let code = Code("1000-guitarist-revenge".to_owned()); + let connector = Wormhole::connect_with_code(APP_CONFIG, code, true) + .await + .unwrap_err(); + match connector { + magic_wormhole::WormholeError::UnclaimedNameplate(x) => { + assert_eq!(x, magic_wormhole::core::Nameplate("1000".to_owned())); + }, + other => panic!( + "Got wrong error type {:?}. Expected `NameplateNotFound`", + other + ), + } + + Ok(()) +} + #[test] fn test_phase() { let p = Phase::PAKE;