diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 92f4bc80..200b3d28 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -20,7 +20,7 @@ clap_complete = "3.1.4" env_logger = "0.10.0" console = "0.15.0" indicatif = "0.17.0" -dialoguer = "0.10.0" +dialoguer = { version = "0.10", features = ["completion"] } color-eyre = "0.6.0" number_prefix = "0.4.0" ctrlc = "3.2.1" diff --git a/cli/src/main.rs b/cli/src/main.rs index 154e0d1d..b58b97a0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -8,9 +8,14 @@ use clap::{Args, CommandFactory, Parser, Subcommand}; use cli_clipboard::{ClipboardContext, ClipboardProvider}; use color_eyre::{eyre, eyre::Context}; use console::{style, Term}; +use dialoguer::Input; use futures::{future::Either, Future, FutureExt}; use indicatif::{MultiProgress, ProgressBar}; -use std::{io::Write, path::PathBuf}; +use magic_wormhole::PgpWordList; +use std::{ + io::Write, + path::{Path, PathBuf}, +}; use magic_wormhole::{ dilated_transfer, forwarding, transfer, transit, MailboxConnection, Wormhole, @@ -570,7 +575,7 @@ async fn main() -> eyre::Result<()> { out, ); - std::io::stdout().write_all(&out.as_bytes())?; + std::io::stdout().write_all(out.as_bytes())?; }, shell => { let mut out = std::io::stdout(); @@ -761,12 +766,14 @@ fn create_progress_handler(pb: ProgressBar) -> impl FnMut(u64, u64) { } fn enter_code() -> eyre::Result { - use dialoguer::Input; - - Input::new() + let completion = WordList::default(); + let input = Input::new() .with_prompt("Enter code") + .completion_with(&completion) .interact_text() - .map_err(From::from) + .map_err(From::from); + + input } fn print_welcome(term: &mut Term, welcome: &Option) -> eyre::Result<()> { @@ -1059,15 +1066,14 @@ async fn receive_inner_v1( .truncate(true) .open(&file_path) .await?; - Ok(req - .accept( - &transit::log_transit_connection, - &mut file, - create_progress_handler(pb), - ctrl_c(), - ) - .await - .context("Receive process failed")?) + req.accept( + &transit::log_transit_connection, + &mut file, + create_progress_handler(pb), + ctrl_c(), + ) + .await + .context("Receive process failed") } async fn receive_inner_v2( @@ -1172,6 +1178,74 @@ async fn receive_inner_v2( Ok(()) } +use dialoguer::Completion; +use magic_wormhole::core::wordlist; +use std::{collections::HashMap, fs}; + +struct WordList(PgpWordList); + +impl Default for WordList { + fn default() -> Self { + let json = fs::read_to_string("./src/core/pgpwords.json").unwrap(); + let word_map: HashMap> = serde_json::from_str(&json).unwrap(); + let mut even_words: Vec = vec![]; + let mut odd_words: Vec = vec![]; + for (_idx, words) in word_map { + even_words.push(words[0].to_lowercase()); + odd_words.push(words[1].to_lowercase()); + } + let words = vec![even_words, odd_words]; + + WordList { + 0: PgpWordList { + words: words.clone(), + num_words: words.len(), + }, + } + } +} + +impl Completion for WordList { + fn get(&self, input: &str) -> Option { + let count_dashes = input.matches('-').count(); + let mut completions = Vec::new(); + let words = &self.0.words[count_dashes % self.0.words.len()]; + + let last_partial_word = input.split('-').last(); + let lp = if let Some(w) = last_partial_word { + w.len() + } else { + 0 + }; + + for word in words { + let mut suffix: String = input.to_owned(); + if word.starts_with(last_partial_word.unwrap()) { + if lp == 0 { + suffix.push_str(word); + } else { + let p = input.len() - lp; + suffix.truncate(p); + suffix.push_str(word); + } + + if count_dashes + 1 < self.0.num_words { + suffix.push('-'); + } + + completions.push(suffix); + } + } + if completions.len() == 1 { + Some(completions.first().unwrap().clone()) + } else { + // TODO: show vector of suggestions somehow + // println!("Suggestions: {:#?}", &completions); + None + } + } +} + #[cfg(test)] mod test { use super::*; @@ -1189,4 +1263,21 @@ mod test { String::from_utf8(out).unwrap(); } } + + #[test] + fn test_passphrase_completion() { + let words: Vec> = vec![ + wordlist::vecstrings("purple green yellow"), + wordlist::vecstrings("sausages seltzer snobol"), + ]; + + let w = WordList(PgpWordList { + words, + num_words: 2, + }); + assert_eq!(w.get(""), None); + assert_eq!(w.get("pur").unwrap(), "purple-"); + assert_eq!(w.get("blu"), None); + assert_eq!(w.get("purple-sa").unwrap(), "purple-sausages"); + } } diff --git a/cli/src/util.rs b/cli/src/util.rs index 51d517ea..d76450bf 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -20,7 +20,7 @@ pub async fn ask_user(message: impl std::fmt::Display, default_answer: bool) -> let mut answer = String::new(); stdin.read_line(&mut answer).await.unwrap(); - match &*answer.to_lowercase().trim() { + match answer.to_lowercase().trim() { "y" | "yes" => break true, "n" | "no" => break false, "" => break default_answer, diff --git a/src/core.rs b/src/core.rs index f99d53df..e7e9aca1 100644 --- a/src/core.rs +++ b/src/core.rs @@ -18,7 +18,7 @@ pub mod rendezvous; mod server_messages; #[cfg(test)] pub(crate) mod test; -mod wordlist; +pub mod wordlist; #[derive(Debug, thiserror::Error)] #[non_exhaustive] @@ -694,9 +694,9 @@ impl Nameplate { } } -impl Into for Nameplate { - fn into(self) -> String { - self.0 +impl From for String { + fn from(val: Nameplate) -> Self { + val.0 } } @@ -723,7 +723,7 @@ impl Code { } pub fn nameplate(&self) -> Nameplate { - Nameplate::new(self.0.splitn(2, '-').next().unwrap()) + Nameplate::new(self.0.split('-').next().unwrap()) } } diff --git a/src/core/key.rs b/src/core/key.rs index 2fd0140a..431c943a 100644 --- a/src/core/key.rs +++ b/src/core/key.rs @@ -49,7 +49,7 @@ impl Key { */ #[cfg(feature = "transit")] pub fn derive_transit_key(&self, appid: &AppID) -> Key { - let transit_purpose = format!("{}/transit-key", &*appid); + let transit_purpose = format!("{}/transit-key", appid); let derived_key = self.derive_subkey_from_purpose(&transit_purpose); trace!( @@ -68,7 +68,7 @@ impl Key

{ } pub fn to_hex(&self) -> String { - hex::encode(&**self) + hex::encode(**self) } /** @@ -76,7 +76,7 @@ impl Key

{ */ pub fn derive_subkey_from_purpose(&self, purpose: &str) -> Key { Key( - Box::new(derive_key(&*self, purpose.as_bytes())), + Box::new(derive_key(self, purpose.as_bytes())), std::marker::PhantomData, ) } diff --git a/src/core/test.rs b/src/core/test.rs index 6123550b..2d852750 100644 --- a/src/core/test.rs +++ b/src/core/test.rs @@ -228,8 +228,9 @@ pub async fn test_file_rust2rust_deprecated() -> eyre::Result<()> { futures::future::pending(), ) .await? - .unwrap() - else {panic!("v2 should be disabled for now")}; + .unwrap() else { + panic!("v2 should be disabled for now") + }; req.accept( &transit::log_transit_connection, &mut answer, @@ -302,8 +303,9 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { futures::future::pending(), ) .await? - .unwrap() - else {panic!("v2 should be disabled for now")}; + .unwrap() else { + panic!("v2 should be disabled for now") + }; req.accept( &transit::log_transit_connection, &mut answer, @@ -413,8 +415,9 @@ pub async fn test_send_many() -> eyre::Result<()> { futures::future::pending(), ) .await? - .unwrap() - else {panic!("v2 should be disabled for now")}; + .unwrap() else { + panic!("v2 should be disabled for now") + }; // Hacky v1-compat conversion for now let mut answer = (gen_accept() diff --git a/src/core/wordlist.rs b/src/core/wordlist.rs index 34bf5cb0..30cf4dbb 100644 --- a/src/core/wordlist.rs +++ b/src/core/wordlist.rs @@ -1,63 +1,19 @@ use rand::{rngs::OsRng, seq::SliceRandom}; -use std::fmt; +use serde_json::{self, Value}; -#[derive(PartialEq)] -pub struct Wordlist { +pub struct PgpWordList { + pub words: Vec>, pub num_words: usize, - words: Vec>, } -impl fmt::Debug for Wordlist { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Wordlist ( {}, lots of words...)", self.num_words) +impl PgpWordList { + pub fn new(num_words: usize, words: Vec>) -> Self { + Self { words, num_words } } -} - -impl Wordlist { - #[cfg(test)] - pub fn new(num_words: usize, words: Vec>) -> Wordlist { - Wordlist { num_words, words } - } - - #[allow(dead_code)] // TODO make this API public one day - pub fn get_completions(&self, prefix: &str) -> Vec { - let count_dashes = prefix.matches('-').count(); - let mut completions = Vec::new(); - let words = &self.words[count_dashes % self.words.len()]; - - let last_partial_word = prefix.split('-').last(); - let lp = if let Some(w) = last_partial_word { - w.len() - } else { - 0 - }; - - for word in words { - let mut suffix: String = prefix.to_owned(); - if word.starts_with(last_partial_word.unwrap()) { - if lp == 0 { - suffix.push_str(&word); - } else { - let p = prefix.len() - lp; - suffix.truncate(p as usize); - suffix.push_str(&word); - } - - if count_dashes + 1 < self.num_words { - suffix.push_str("-"); - } - - completions.push(suffix); - } - } - completions.sort(); - completions - } - pub fn choose_words(&self) -> String { let mut rng = OsRng; - let components: Vec; - components = self + + let components: Vec = self .words .iter() .cycle() @@ -97,13 +53,25 @@ fn load_pgpwords() -> Vec> { vec![even_words, odd_words] } -pub fn default_wordlist(num_words: usize) -> Wordlist { - Wordlist { +pub fn default_wordlist(num_words: usize) -> PgpWordList { + PgpWordList { num_words, words: load_pgpwords(), } } +pub fn vecstrings(all: &str) -> Vec { + all.split_whitespace() + .map(|s| { + if s == "." { + String::from("") + } else { + s.to_string() + } + }) + .collect() +} + #[cfg(test)] mod test { use super::*; @@ -118,51 +86,15 @@ mod test { assert_eq!(w[1][255], "zulu"); } - #[test] - fn test_default_wordlist() { - let d = default_wordlist(2); - assert_eq!(d.words.len(), 2); - assert_eq!(d.words[0][0], "adroitness"); - assert_eq!(d.words[1][0], "aardvark"); - assert_eq!(d.words[0][255], "yucatan"); - assert_eq!(d.words[1][255], "zulu"); - } - - fn vecstrings(all: &str) -> Vec { - all.split_whitespace() - .map(|s| { - if s == "." { - String::from("") - } else { - s.to_string() - } - }) - .collect() - } - - #[test] - fn test_completion() { - let words: Vec> = vec![ - vecstrings("purple green yellow"), - vecstrings("sausages seltzer snobol"), - ]; - - let w = Wordlist::new(2, words); - assert_eq!(w.get_completions(""), vec!["green-", "purple-", "yellow-"]); - assert_eq!(w.get_completions("pur"), vec!["purple-"]); - assert_eq!(w.get_completions("blu"), Vec::::new()); - assert_eq!(w.get_completions("purple-sa"), vec!["purple-sausages"]); - } - #[test] fn test_choose_words() { let few_words: Vec> = vec![vecstrings("purple"), vecstrings("sausages")]; - let w = Wordlist::new(2, few_words.clone()); + let w = PgpWordList::new(2, few_words.clone()); assert_eq!(w.choose_words(), "purple-sausages"); - let w = Wordlist::new(3, few_words.clone()); + let w = PgpWordList::new(3, few_words.clone()); assert_eq!(w.choose_words(), "purple-sausages-purple"); - let w = Wordlist::new(4, few_words); + let w = PgpWordList::new(4, few_words); assert_eq!(w.choose_words(), "purple-sausages-purple-sausages"); } @@ -182,57 +114,14 @@ mod test { .map(|s| s.to_string()) .collect(); - let w = Wordlist::new(2, more_words.clone()); + let w = PgpWordList::new(2, more_words.clone()); for _ in 0..20 { assert!(expected2.contains(&w.choose_words())); } - let w = Wordlist::new(3, more_words); + let w = PgpWordList::new(3, more_words); for _ in 0..20 { assert!(expected3.contains(&w.choose_words())); } } - - #[test] - fn test_default_completions() { - let w = default_wordlist(2); - let c = w.get_completions("ar"); - assert_eq!(c.len(), 2); - assert!(c.contains(&String::from("article-"))); - assert!(c.contains(&String::from("armistice-"))); - - let c = w.get_completions("armis"); - assert_eq!(c.len(), 1); - assert!(c.contains(&String::from("armistice-"))); - - let c = w.get_completions("armistice-"); - assert_eq!(c.len(), 256); - - let c = w.get_completions("armistice-ba"); - assert_eq!( - c, - vec![ - "armistice-baboon", - "armistice-backfield", - "armistice-backward", - "armistice-banjo", - ] - ); - - let w = default_wordlist(3); - let c = w.get_completions("armistice-ba"); - assert_eq!( - c, - vec![ - "armistice-baboon-", - "armistice-backfield-", - "armistice-backward-", - "armistice-banjo-", - ] - ); - - let w = default_wordlist(4); - let c = w.get_completions("armistice-baboon"); - assert_eq!(c, vec!["armistice-baboon-"]); - } } diff --git a/src/lib.rs b/src/lib.rs index 3be0651c..e7cbc158 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ #[macro_use] mod util; -mod core; +pub mod core; #[cfg(feature = "dilation")] pub mod dilated_transfer; #[cfg(feature = "dilation")] @@ -41,8 +41,9 @@ pub mod uri; pub use crate::core::{ key::{GenericKey, Key, KeyPurpose, WormholeKey}, - rendezvous, AppConfig, AppID, Code, MailboxConnection, Mood, Nameplate, Wormhole, - WormholeError, + rendezvous, + wordlist::PgpWordList, + AppConfig, AppID, Code, MailboxConnection, Mood, Nameplate, Wormhole, WormholeError, }; #[cfg(feature = "dilation")] diff --git a/src/transfer/cancel.rs b/src/transfer/cancel.rs index bb99ddbd..4837b0f9 100644 --- a/src/transfer/cancel.rs +++ b/src/transfer/cancel.rs @@ -49,7 +49,11 @@ macro_rules! with_cancel_wormhole { ($wormhole:ident, run = $run:expr, $cancel:expr, ret_cancel = $ret_cancel:expr $(,)?) => {{ let run = Box::pin($run); let result = cancel::cancellable_2(run, $cancel).await; - let Some((transit, wormhole, cancel)) = cancel::handle_run_result_noclose($wormhole, result).await? else { return Ok($ret_cancel); }; + let Some((transit, wormhole, cancel)) = + cancel::handle_run_result_noclose($wormhole, result).await? + else { + return Ok($ret_cancel); + }; (transit, wormhole, cancel) }}; } @@ -238,8 +242,8 @@ pub async fn handle_run_result_transit( */ loop { let Ok(msg) = transit.receive_record().await else { - break; - }; + break; + }; match parse_message(&msg) { Ok(None) => continue, Ok(Some(err)) => { diff --git a/src/transit/crypto.rs b/src/transit/crypto.rs index 1b64c90c..475316a5 100644 --- a/src/transit/crypto.rs +++ b/src/transit/crypto.rs @@ -263,7 +263,7 @@ impl TransitCryptoInit for NoiseInit { builder.set_is_initiator(true); builder.build_handshake_state() }; - handshake.push_psk(&*self.key); + handshake.push_psk(&self.key); // → psk, e socket @@ -279,7 +279,7 @@ impl TransitCryptoInit for NoiseInit { // ← "" let peer_confirmation_message = rx.decrypt_vec(&socket.read_transit_message().await?)?; ensure!( - peer_confirmation_message.len() == 0, + peer_confirmation_message.is_empty(), TransitHandshakeError::HandshakeFailed ); @@ -330,7 +330,7 @@ impl TransitCryptoInit for NoiseInit { builder.set_is_initiator(false); builder.build_handshake_state() }; - handshake.push_psk(&*self.key); + handshake.push_psk(&self.key); // ← psk, e handshake.read_message(&socket.read_transit_message().await?, &mut [])?; @@ -350,7 +350,7 @@ impl TransitCryptoInit for NoiseInit { // ← "" let peer_confirmation_message = rx.decrypt_vec(&socket.read_transit_message().await?)?; ensure!( - peer_confirmation_message.len() == 0, + peer_confirmation_message.is_empty(), TransitHandshakeError::HandshakeFailed ); diff --git a/src/uri.rs b/src/uri.rs index 7713700e..e071401d 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -125,13 +125,13 @@ impl std::str::FromStr for WormholeTransferUri { impl From<&WormholeTransferUri> for url::Url { fn from(val: &WormholeTransferUri) -> Self { let mut url = url::Url::parse("wormhole-transfer:").unwrap(); - url.set_path(&*val.code); + url.set_path(&val.code); /* Only do this if there are any query parameteres at all, otherwise the URL will have an ugly trailing '?'. */ if val.rendezvous_server.is_some() || val.is_leader { let mut query = url.query_pairs_mut(); query.clear(); if let Some(rendezvous_server) = val.rendezvous_server.as_ref() { - query.append_pair("rendezvous", &rendezvous_server.to_string()); + query.append_pair("rendezvous", rendezvous_server.as_ref()); } if val.is_leader { query.append_pair("role", "leader"); diff --git a/src/util.rs b/src/util.rs index 6f3fd2e8..6a172af2 100644 --- a/src/util.rs +++ b/src/util.rs @@ -17,7 +17,7 @@ pub struct DisplayBytes<'a>(pub &'a [u8]); impl std::fmt::Display for DisplayBytes<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let hex_decode = hex::decode(&self.0); + let hex_decode = hex::decode(self.0); let (string, hex_param) = match hex_decode.as_ref().map(Vec::as_slice) { Ok(decoded_hex) => (decoded_hex, "hex-encoded "), Err(_) => (self.0, ""),