From 3d6fabe2704f337545ac0b488146f364ab0fcf76 Mon Sep 17 00:00:00 2001 From: Arthur Welf Date: Thu, 28 Nov 2024 12:02:31 +0100 Subject: [PATCH] rusk-wallet: Refactor Wallet struct to make impossible states impossible Fix merge conflict --- rusk-wallet/src/bin/command.rs | 6 +- rusk-wallet/src/bin/interactive.rs | 48 ++-- rusk-wallet/src/bin/main.rs | 57 ++--- rusk-wallet/src/dat.rs | 94 +------ rusk-wallet/src/error.rs | 3 + rusk-wallet/src/lib.rs | 3 +- rusk-wallet/src/wallet.rs | 89 +++---- rusk-wallet/src/wallet/file.rs | 262 ++++++++++++++----- rusk-wallet/src/wallet/file_service.rs | 341 +++++++++++++++++++++++++ rusk-wallet/src/wallet/transaction.rs | 2 +- 10 files changed, 642 insertions(+), 263 deletions(-) create mode 100644 rusk-wallet/src/wallet/file_service.rs diff --git a/rusk-wallet/src/bin/command.rs b/rusk-wallet/src/bin/command.rs index 3e02333264..b34409758f 100644 --- a/rusk-wallet/src/bin/command.rs +++ b/rusk-wallet/src/bin/command.rs @@ -21,14 +21,14 @@ use rusk_wallet::gas::{ DEFAULT_PRICE, MIN_PRICE_DEPLOYMENT, }; use rusk_wallet::{ - Address, Error, Profile, Wallet, EPOCH, MAX_CONTRACT_INIT_ARG_SIZE, - MAX_PROFILES, + Address, Error, Profile, Wallet, WalletPath, EPOCH, + MAX_CONTRACT_INIT_ARG_SIZE, MAX_PROFILES, }; use wallet_core::BalanceInfo; use crate::io::prompt; use crate::settings::Settings; -use crate::{WalletFile, WalletPath}; +use crate::WalletFile; /// Commands that can be run against the Dusk wallet #[allow(clippy::large_enum_variant)] diff --git a/rusk-wallet/src/bin/interactive.rs b/rusk-wallet/src/bin/interactive.rs index 5e371ade15..59357c645f 100644 --- a/rusk-wallet/src/bin/interactive.rs +++ b/rusk-wallet/src/bin/interactive.rs @@ -10,9 +10,13 @@ use std::fmt::Display; use bip39::{Language, Mnemonic, MnemonicType}; use inquire::{InquireError, Select}; + use rusk_wallet::currency::Dusk; use rusk_wallet::dat::{DatFileVersion, LATEST_VERSION}; -use rusk_wallet::{Address, Error, Profile, Wallet, WalletPath, MAX_PROFILES}; +use rusk_wallet::{ + Address, Error, Profile, SecureWalletFile, Wallet, WalletFilePath, + WalletPath, MAX_PROFILES, +}; use crate::io::{self, prompt}; use crate::settings::Settings; @@ -138,6 +142,9 @@ async fn profile_idx( match menu_profile(wallet)? { ProfileSelect::Index(index, _) => Ok(index), ProfileSelect::New => { + // get the wallet file + let file = wallet.file().clone().ok_or(Error::WalletFileMissing)?; + if wallet.profiles().len() >= MAX_PROFILES { println!( "Cannot create more profiles, this wallet only supports up to {MAX_PROFILES} profiles" @@ -147,24 +154,22 @@ async fn profile_idx( } let profile_idx = wallet.add_profile(); - let file_version = wallet.get_file_version()?; let password = &settings.password; // if the version file is old, ask for password and save as // latest dat file - if file_version.is_old() { + if file.is_old() { let pwd = prompt::request_auth( "Updating your wallet data file, please enter your wallet password ", password, DatFileVersion::RuskBinaryFileFormat(LATEST_VERSION), )?; - // UNWRAP: we can safely unwrap here because we know the file is - // not None since we've checked the file version - wallet.save_to(WalletFile { - path: wallet.file().clone().unwrap().path, + wallet.save_to(WalletFile::new( + file.path().clone(), pwd, - })?; + DatFileVersion::RuskBinaryFileFormat(LATEST_VERSION), + ))?; } else { // else just save wallet.save()?; @@ -231,8 +236,10 @@ pub(crate) async fn load_wallet( settings: &Settings, file_version: Result, ) -> anyhow::Result> { - let wallet_found = - wallet_path.inner().exists().then(|| wallet_path.clone()); + let wallet_found = wallet_path + .wallet_path() + .exists() + .then(|| wallet_path.clone()); let password = &settings.password; @@ -247,10 +254,11 @@ pub(crate) async fn load_wallet( password, file_version, )?; - match Wallet::from_file(WalletFile { - path: path.clone(), + match Wallet::from_file(WalletFile::new( + path.clone(), pwd, - }) { + file_version, + )) { Ok(wallet) => break wallet, Err(_) if attempt > 2 => { Err(Error::AttemptsExhausted)?; @@ -277,7 +285,11 @@ pub(crate) async fn load_wallet( // create and store the wallet let mut w = Wallet::new(mnemonic)?; let path = wallet_path.clone(); - w.save_to(WalletFile { path, pwd })?; + w.save_to(WalletFile::new( + path, + pwd, + DatFileVersion::RuskBinaryFileFormat(LATEST_VERSION), + ))?; w } MainMenu::Recover => { @@ -292,7 +304,11 @@ pub(crate) async fn load_wallet( // create and store the recovered wallet let mut w = Wallet::new(phrase)?; let path = wallet_path.clone(); - w.save_to(WalletFile { path, pwd })?; + w.save_to(WalletFile::new( + path, + pwd, + DatFileVersion::RuskBinaryFileFormat(LATEST_VERSION), + ))?; w } MainMenu::Exit => std::process::exit(0), @@ -493,7 +509,7 @@ impl Display for MainMenu { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { MainMenu::Load(path) => { - write!(f, "Load wallet from {}", path.wallet.display()) + write!(f, "Load wallet from {}", path.wallet_path().display()) } MainMenu::Create => write!(f, "Create a new wallet"), MainMenu::Recover => { diff --git a/rusk-wallet/src/bin/main.rs b/rusk-wallet/src/bin/main.rs index 445d99d521..1864ad8d4d 100644 --- a/rusk-wallet/src/bin/main.rs +++ b/rusk-wallet/src/bin/main.rs @@ -11,6 +11,7 @@ mod io; mod settings; pub(crate) use command::{Command, RunResult}; +use rusk_wallet::{WalletFilePath, WalletPath}; use std::fs::{self, File}; use std::io::Write; @@ -19,35 +20,20 @@ use bip39::{Language, Mnemonic, MnemonicType}; use clap::Parser; use inquire::InquireError; use rocksdb::ErrorKind; -use rusk_wallet::currency::Dusk; -use rusk_wallet::dat::{self, LATEST_VERSION}; -use rusk_wallet::{ - Error, GraphQL, Profile, SecureWalletFile, Wallet, WalletPath, EPOCH, -}; use tracing::{error, info, warn, Level}; use crate::command::TransactionHistory; use crate::settings::{LogFormat, Settings}; +use rusk_wallet::{ + currency::Dusk, + dat::{self, LATEST_VERSION}, + Error, GraphQL, Profile, SecureWalletFile, Wallet, WalletFile, EPOCH, +}; + use config::Config; use io::{prompt, status, WalletArgs}; -#[derive(Debug, Clone)] -pub(crate) struct WalletFile { - path: WalletPath, - pwd: Vec, -} - -impl SecureWalletFile for WalletFile { - fn path(&self) -> &WalletPath { - &self.path - } - - fn pwd(&self) -> &[u8] { - &self.pwd - } -} - #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { if let Err(err) = exec().await { @@ -139,7 +125,7 @@ async fn exec() -> anyhow::Result<()> { // prepare wallet path let mut wallet_path = - WalletPath::from(wallet_dir.as_path().join("wallet.dat")); + WalletPath::try_from(wallet_dir.as_path().join("wallet.dat"))?; // load configuration (or use default) let cfg = Config::load(&wallet_dir)?; @@ -193,6 +179,7 @@ async fn exec() -> anyhow::Result<()> { return Ok(()); }; + // get the wallet file version let file_version = dat::read_file_version(&wallet_path); // get our wallet ready @@ -232,10 +219,7 @@ async fn exec() -> anyhow::Result<()> { // create wallet let mut w = Wallet::new(mnemonic)?; - w.save_to(WalletFile { - path: wallet_path, - pwd, - })?; + w.save_to(WalletFile::new(wallet_path, pwd, file_version?))?; w } @@ -253,10 +237,11 @@ async fn exec() -> anyhow::Result<()> { file_version, )?; - let w = Wallet::from_file(WalletFile { - path: file.clone(), - pwd: pwd.clone(), - })?; + let w = Wallet::from_file(WalletFile::new( + file.clone(), + pwd.clone(), + file_version, + ))?; (w, pwd) } @@ -279,10 +264,7 @@ async fn exec() -> anyhow::Result<()> { } }; - w.save_to(WalletFile { - path: wallet_path, - pwd, - })?; + w.save_to(WalletFile::new(wallet_path, pwd, file_version?))?; w } @@ -297,10 +279,11 @@ async fn exec() -> anyhow::Result<()> { file_version, )?; - Wallet::from_file(WalletFile { - path: wallet_path, + Wallet::from_file(WalletFile::new( + wallet_path, pwd, - })? + file_version, + ))? } }, }; diff --git a/rusk-wallet/src/dat.rs b/rusk-wallet/src/dat.rs index c0a243e46d..d34005d46d 100644 --- a/rusk-wallet/src/dat.rs +++ b/rusk-wallet/src/dat.rs @@ -9,10 +9,7 @@ use std::fs; use std::io::Read; -use wallet_core::Seed; - -use crate::crypto::decrypt; -use crate::{Error, WalletPath}; +use crate::{Error, WalletFilePath, WalletPath}; /// Binary prefix for old Dusk wallet files pub const OLD_MAGIC: u32 = 0x1d0c15; @@ -38,89 +35,6 @@ pub enum DatFileVersion { RuskBinaryFileFormat(Version), } -impl DatFileVersion { - /// Checks if the file version is older than the latest Rust Binary file - /// format - pub fn is_old(&self) -> bool { - matches!(self, Self::Legacy | Self::OldWalletCli(_)) - } -} - -/// Make sense of the payload and return it -pub(crate) fn get_seed_and_address( - file: DatFileVersion, - mut bytes: Vec, - pwd: &[u8], -) -> Result<(Seed, u8), Error> { - match file { - DatFileVersion::Legacy => { - if bytes[1] == 0 && bytes[2] == 0 { - bytes.drain(..3); - } - - bytes = decrypt(&bytes, pwd)?; - - // get our seed - let seed = bytes[..] - .try_into() - .map_err(|_| Error::WalletFileCorrupted)?; - - Ok((seed, 1)) - } - DatFileVersion::OldWalletCli((major, minor, _, _, _)) => { - bytes.drain(..5); - - let result: Result<(Seed, u8), Error> = match (major, minor) { - (1, 0) => { - let content = decrypt(&bytes, pwd)?; - let buff = &content[..]; - - let seed = buff - .try_into() - .map_err(|_| Error::WalletFileCorrupted)?; - - Ok((seed, 1)) - } - (2, 0) => { - let content = decrypt(&bytes, pwd)?; - let buff = &content[..]; - - // extract seed - let seed = buff - .try_into() - .map_err(|_| Error::WalletFileCorrupted)?; - - // extract addresses count - Ok((seed, buff[0])) - } - _ => Err(Error::UnknownFileVersion(major, minor)), - }; - - result - } - DatFileVersion::RuskBinaryFileFormat(_) => { - let rest = bytes.get(12..(12 + 96)); - if let Some(rest) = rest { - let content = decrypt(rest, pwd)?; - - if let Some(seed_buff) = content.get(0..65) { - let seed = seed_buff[0..64] - .try_into() - .map_err(|_| Error::WalletFileCorrupted)?; - - let count = &seed_buff[64..65]; - - Ok((seed, count[0])) - } else { - Err(Error::WalletFileCorrupted) - } - } else { - Err(Error::WalletFileCorrupted) - } - } - } -} - /// From the first 12 bytes of the file (header), we check version /// /// https://github.com/dusk-network/rusk/wiki/Binary-File-Format/#header @@ -193,8 +107,10 @@ pub(crate) fn check_version( /// Read the first 12 bytes of the dat file and get the file version from /// there -pub fn read_file_version(file: &WalletPath) -> Result { - let path = &file.wallet; +pub fn read_file_version( + wallet_file_path: &WalletPath, +) -> Result { + let path = &wallet_file_path.wallet_path(); // make sure file exists if !path.is_file() { diff --git a/rusk-wallet/src/error.rs b/rusk-wallet/src/error.rs index 8b39a919e3..456cfabcd8 100644 --- a/rusk-wallet/src/error.rs +++ b/rusk-wallet/src/error.rs @@ -153,6 +153,9 @@ pub enum Error { /// Contract file location not found #[error("Invalid WASM contract path provided")] InvalidWasmContractPath, + /// Invalid wallet file path + #[error("Invalid wallet file path")] + InvalidWalletFilePath, /// Invalid environment variable value #[error("Invalid environment variable value {0}")] InvalidEnvVar(String), diff --git a/rusk-wallet/src/lib.rs b/rusk-wallet/src/lib.rs index 1529288a98..ea9145d7fb 100644 --- a/rusk-wallet/src/lib.rs +++ b/rusk-wallet/src/lib.rs @@ -31,7 +31,8 @@ pub use error::Error; pub use gql::{BlockTransaction, GraphQL}; pub use rues::RuesHttpClient; pub use wallet::{ - Address, DecodedNote, Profile, SecureWalletFile, Wallet, WalletPath, + Address, DecodedNote, Profile, SecureWalletFile, Wallet, WalletFile, + WalletFilePath, WalletPath, }; use execution_core::stake::StakeData; diff --git a/rusk-wallet/src/wallet.rs b/rusk-wallet/src/wallet.rs index ee2999bcc5..fa4d6f0a8f 100644 --- a/rusk-wallet/src/wallet.rs +++ b/rusk-wallet/src/wallet.rs @@ -6,10 +6,12 @@ mod address; mod file; +mod file_service; mod transaction; pub use address::{Address, Profile}; -pub use file::{SecureWalletFile, WalletPath}; +pub use file::{WalletFile, WalletPath}; +pub use file_service::{SecureWalletFile, WalletFilePath}; use std::fmt::Debug; use std::fs; @@ -38,8 +40,7 @@ use crate::clients::State; use crate::crypto::encrypt; use crate::currency::Dusk; use crate::dat::{ - self, version_bytes, DatFileVersion, FILE_TYPE, LATEST_VERSION, MAGIC, - RESERVED, + version_bytes, DatFileVersion, FILE_TYPE, LATEST_VERSION, MAGIC, RESERVED, }; use crate::gas::MempoolGasPrices; use crate::rues::RuesHttpClient; @@ -59,12 +60,11 @@ use crate::Error; /// A wallet must connect to the network using a [`RuskEndpoint`] in order to be /// able to perform common operations such as checking balance, transfernig /// funds, or staking Dusk. -pub struct Wallet { +pub struct Wallet { profiles: Vec, state: Option, store: LocalStore, file: Option, - file_version: Option, } impl Wallet { @@ -106,7 +106,6 @@ impl Wallet { state: None, store: LocalStore::from(seed_bytes), file: None, - file_version: None, }) } else { Err(Error::InvalidMnemonicPhrase) @@ -115,25 +114,11 @@ impl Wallet { /// Loads wallet given a session pub fn from_file(file: F) -> Result { - let path = file.path(); - let pwd = file.pwd(); - - // make sure file exists - let pb = path.inner().clone(); - if !pb.is_file() { - return Err(Error::WalletFileMissing); - } - - // attempt to load and decode wallet - let bytes = fs::read(&pb)?; - - let file_version = dat::check_version(bytes.get(0..12))?; - - let (seed, address_count) = - dat::get_seed_and_address(file_version, bytes, pwd)?; + // Get the seed and address count from the file + let (seed, address_count) = file.get_seed_and_address()?; // return early if its legacy - if let DatFileVersion::Legacy = file_version { + if let DatFileVersion::Legacy = file.version() { // Generate the default address at index 0 let profiles = vec![Profile { shielded_addr: derive_phoenix_pk(&seed, 0), @@ -146,7 +131,6 @@ impl Wallet { store: LocalStore::from(seed), state: None, file: Some(file), - file_version: Some(DatFileVersion::Legacy), }); } @@ -163,7 +147,6 @@ impl Wallet { store: LocalStore::from(seed), state: None, file: Some(file), - file_version: Some(file_version), }) } @@ -196,7 +179,7 @@ impl Wallet { content.extend_from_slice(&payload); // write the content to file - fs::write(&f.path().wallet, content)?; + fs::write(f.wallet_path(), content)?; Ok(()) } None => Err(Error::WalletFileMissing), @@ -399,15 +382,10 @@ impl Wallet { /// get cache database path pub(crate) fn cache_path(&self) -> Result { - let cache_dir = { - if let Some(file) = &self.file { - file.path().cache_dir() - } else { - return Err(Error::WalletFileMissing); - } - }; - - Ok(cache_dir) + match self.file() { + Some(file) => Ok(file.cache_dir()), + None => Err(Error::WalletFileMissing), + } } /// Returns the shielded key for a given index. @@ -573,10 +551,8 @@ impl Wallet { /// Return the dat file version from memory or by reading the file /// In order to not read the file version more than once per execution pub fn get_file_version(&self) -> Result { - if let Some(file_version) = self.file_version { - Ok(file_version) - } else if let Some(file) = &self.file { - Ok(dat::read_file_version(file.path())?) + if let Some(file) = &self.file { + Ok(file.version()) } else { Err(Error::WalletFileMissing) } @@ -659,28 +635,14 @@ mod base64 { #[cfg(test)] mod tests { + use std::fs::File; + use tempfile::tempdir; use super::*; const TEST_ADDR: &str = "2w7fRQW23Jn9Bgm1GQW9eC2bD9U883dAwqP7HAr2F8g1syzPQaPYrxSyyVZ81yDS5C1rv9L8KjdPBsvYawSx3QCW"; - #[derive(Debug, Clone)] - struct WalletFile { - path: WalletPath, - pwd: Vec, - } - - impl SecureWalletFile for WalletFile { - fn path(&self) -> &WalletPath { - &self.path - } - - fn pwd(&self) -> &[u8] { - &self.pwd - } - } - #[test] fn wallet_basics() -> Result<(), Box> { // create a wallet from a mnemonic phrase @@ -715,15 +677,23 @@ mod tests { fn save_and_load() -> Result<(), Box> { // prepare a tmp path let dir = tempdir()?; - let path = dir.path().join("my_wallet.dat"); - let path = WalletPath::from(path); + let tmp_file_path = dir.path().join("my_wallet.dat"); + // we need to create a real file because the `WalletFile::try_from` + // checks if the file exists. The file will be deleted when the test + // ends. + let tmp_file = File::create(&tmp_file_path)?; + let path = WalletPath::try_from(tmp_file_path.clone()).unwrap(); // we'll need a password too let pwd = blake3::hash("mypassword".as_bytes()).as_bytes().to_vec(); // create and save let mut wallet: Wallet = Wallet::new("uphold stove tennis fire menu three quick apple close guilt poem garlic volcano giggle comic")?; - let file = WalletFile { path, pwd }; + let file = WalletFile::new( + path, + pwd, + DatFileVersion::RuskBinaryFileFormat(LATEST_VERSION), + ); wallet.save_to(file.clone())?; // load from file and check @@ -733,6 +703,9 @@ mod tests { let loaded_addr = loaded_wallet.default_shielded_address(); assert!(original_addr.eq(&loaded_addr)); + drop(tmp_file); + dir.close()?; + Ok(()) } } diff --git a/rusk-wallet/src/wallet/file.rs b/rusk-wallet/src/wallet/file.rs index 0e3fca78dd..51375140a6 100644 --- a/rusk-wallet/src/wallet/file.rs +++ b/rusk-wallet/src/wallet/file.rs @@ -4,107 +4,155 @@ // // Copyright (c) DUSK NETWORK. All rights reserved. -use std::fmt; +use std::fmt::{Debug, Display}; +use std::hash::Hash; use std::path::{Path, PathBuf}; use std::str::FromStr; -/// Provides access to a secure wallet file -pub trait SecureWalletFile { - /// Returns the path - fn path(&self) -> &WalletPath; - /// Returns the hashed password - fn pwd(&self) -> &[u8]; +use crate::dat::DatFileVersion; +use crate::{Error, SecureWalletFile}; + +use super::file_service::WalletFilePath; + +/// Wallet file structure that contains the path of the wallet file, the hashed +/// password, and the file version +#[derive(Debug, Clone)] +pub struct WalletFile { + path: WalletPath, + pwd: Vec, + file_version: DatFileVersion, +} + +impl SecureWalletFile for WalletFile { + type PathBufWrapper = WalletPath; + + fn path(&self) -> &WalletPath { + &self.path + } + + fn path_mut(&mut self) -> &mut WalletPath { + &mut self.path + } + + fn pwd(&self) -> &[u8] { + &self.pwd + } + + fn version(&self) -> DatFileVersion { + self.file_version + } +} + +impl WalletFile { + /// Create a new wallet file + pub fn new( + path: WalletPath, + pwd: Vec, + file_version: DatFileVersion, + ) -> Self { + Self { + path, + pwd, + file_version, + } + } } /// Wrapper around `PathBuf` for wallet paths #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct WalletPath { - /// Path of the wallet file - pub wallet: PathBuf, + /// Path to the wallet file + wallet: PathBuf, /// Directory of the profile - pub profile_dir: PathBuf, + profile_dir: PathBuf, /// Name of the network - pub network: Option, + network: Option, +} + +impl WalletFilePath for WalletPath { + fn wallet_path(&self) -> &PathBuf { + &self.wallet + } + + fn wallet_path_mut(&mut self) -> &mut PathBuf { + &mut self.wallet + } + + fn profile_dir(&self) -> &PathBuf { + &self.profile_dir + } + + fn network(&self) -> Option<&String> { + self.network.as_ref() + } + + fn network_mut(&mut self) -> &mut Option { + &mut self.network + } } impl WalletPath { /// Create wallet path from the path of "wallet.dat" file. The wallet.dat /// file should be located in the profile folder, this function also /// generates the profile folder from the passed argument - pub fn new(wallet: &Path) -> Self { - let wallet = wallet.to_path_buf(); + pub fn new(wallet_file_path: &Path) -> Result { + let wallet = wallet_file_path.to_path_buf(); // The wallet should be in the profile folder let mut profile_dir = wallet.clone(); - profile_dir.pop(); + let is_valid_dir = profile_dir.pop(); - Self { + if !is_valid_dir { + return Err(Error::InvalidWalletFilePath); + } + + Ok(Self { wallet, profile_dir, network: None, - } + }) } +} - /// Returns the filename of this path - pub fn name(&self) -> Option { - // extract the name - let name = self.wallet.file_stem()?.to_str()?; - Some(String::from(name)) - } +impl TryFrom for WalletPath { + type Error = Error; - /// Returns current directory for this path - pub fn dir(&self) -> Option { - self.wallet.parent().map(PathBuf::from) - } + fn try_from(p: PathBuf) -> Result { + let p = p.to_path_buf(); - /// Returns a reference to the `PathBuf` holding the path - pub fn inner(&self) -> &PathBuf { - &self.wallet - } + let is_valid = + p.try_exists().map_err(|_| Error::InvalidWalletFilePath)? + && p.is_file(); - /// Sets the network name for different cache locations. - /// e.g, devnet, testnet, etc. - pub fn set_network_name(&mut self, network: Option) { - self.network = network; + if !is_valid { + return Err(Error::InvalidWalletFilePath); + } + + Self::new(&p) } +} - /// Generates dir for cache based on network specified - pub fn cache_dir(&self) -> PathBuf { - let mut cache = self.profile_dir.clone(); +impl TryFrom<&Path> for WalletPath { + type Error = Error; - if let Some(network) = &self.network { - cache.push(format!("cache_{network}")); - } else { - cache.push("cache"); - } + fn try_from(p: &Path) -> Result { + let p = p.to_path_buf(); - cache + Self::try_from(p) } } impl FromStr for WalletPath { - type Err = crate::Error; + type Err = Error; fn from_str(s: &str) -> Result { let p = Path::new(s); - Ok(Self::new(p)) + Self::try_from(p) } } -impl From for WalletPath { - fn from(p: PathBuf) -> Self { - Self::new(&p) - } -} - -impl From<&Path> for WalletPath { - fn from(p: &Path) -> Self { - Self::new(p) - } -} - -impl fmt::Display for WalletPath { +impl Display for WalletPath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -115,3 +163,101 @@ impl fmt::Display for WalletPath { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::File; + use tempfile::tempdir; + + #[test] + fn test_wallet_path_creation() -> Result<(), Error> { + let dir = tempdir()?; + let wallet_file = dir.path().join("wallet.dat"); + let file = File::create(&wallet_file)?; + + let wallet_path = WalletPath::new(&wallet_file)?; + + assert_eq!(wallet_path.wallet_path(), &wallet_file, "wallet path is not correct for WalletPath created by WalletPath::new method"); + assert_eq!(wallet_path.profile_dir(), dir.path(), "profile dir is not correct for WalletPath created by WalletPath::new method"); + assert_eq!(wallet_path.network(), None, "network is not correct for WalletPath created by WalletPath::new method"); + + // try_from(PathBuf) + let wallet_path = WalletPath::try_from(wallet_file.clone())?; + + assert_eq!(wallet_path.wallet_path(), &wallet_file, "wallet path is not correct for WalletPath created by WalletPath::try_from(PathBuf) method"); + assert_eq!(wallet_path.profile_dir(), dir.path(), "profile dir is not correct for WalletPath created by WalletPath::try_from(PathBuf) method"); + assert_eq!(wallet_path.network(), None, "network is not correct for WalletPath created by WalletPath::try_from(PathBuf) method"); + + // try_from(&Path) + let wallet_path = WalletPath::try_from(wallet_file.as_path())?; + + assert_eq!(wallet_path.wallet_path(), &wallet_file, "wallet path is not correct for WalletPath created by WalletPath::try_from(&Path) method"); + assert_eq!(wallet_path.profile_dir(), dir.path(), "profile dir is not correct for WalletPath created by WalletPath::try_from(&Path) method"); + assert_eq!(wallet_path.network(), None, "network is not correct for WalletPath created by WalletPath::try_from(&Path) method"); + + // from_str + let wallet_path = WalletPath::from_str(wallet_file.to_str().unwrap())?; + + assert_eq!(wallet_path.wallet_path(), &wallet_file, "wallet path is not correct for WalletPath created by WalletPath::from_str method"); + assert_eq!(wallet_path.profile_dir(), dir.path(), "profile dir is not correct for WalletPath created by WalletPath::from_str method"); + assert_eq!(wallet_path.network(), None, "network is not correct for WalletPath created by WalletPath::from_str method"); + + // the path is not a file + let wallet_path = WalletPath::try_from(dir.path()); + + assert!( + wallet_path.is_err(), + "WalletPath::try_from should return an error when the path is not a file" + ); + + // the path does not exist + let wallet_path = WalletPath::from_str("invalid_path"); + + assert!( + wallet_path.is_err(), + "WalletPath::try_from should return an error when the path does not exist" + ); + + drop(file); + dir.close()?; + + Ok(()) + } + + #[test] + fn test_wallet_file_creation() -> Result<(), Error> { + let dir = tempdir()?; + let wallet_file = dir.path().join("wallet.dat"); + let file = File::create(&wallet_file)?; + + let path = WalletPath::new(&wallet_file)?; + let pwd = vec![1, 2, 3, 4]; + let file_version = + DatFileVersion::RuskBinaryFileFormat((1, 0, 0, 0, false)); + + let wallet_file = + WalletFile::new(path.clone(), pwd.clone(), file_version); + + assert_eq!( + wallet_file.path(), + &path, + "path is not correct for WalletFile" + ); + assert_eq!( + wallet_file.pwd(), + &pwd, + "pwd is not correct for WalletFile" + ); + assert_eq!( + wallet_file.version(), + file_version, + "file_version is not correct for WalletFile" + ); + + drop(file); + dir.close()?; + + Ok(()) + } +} diff --git a/rusk-wallet/src/wallet/file_service.rs b/rusk-wallet/src/wallet/file_service.rs new file mode 100644 index 0000000000..700b0deda3 --- /dev/null +++ b/rusk-wallet/src/wallet/file_service.rs @@ -0,0 +1,341 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// Copyright (c) DUSK NETWORK. All rights reserved. + +use std::fmt::Debug; +use std::hash::Hash; +use std::path::PathBuf; + +use wallet_core::Seed; + +use crate::crypto::decrypt; +use crate::dat::DatFileVersion; +use crate::Error; + +/// Provides access to a secure wallet file +pub trait SecureWalletFile: Debug + Send + Sync + Clone { + /// The type of the path buffer wrapper + type PathBufWrapper: WalletFilePath + Hash + Eq + PartialEq + Debug + Clone; + + // Methods to implement =================================================== + + /// Returns the path + fn path(&self) -> &Self::PathBufWrapper; + /// Return the mutable path + fn path_mut(&mut self) -> &mut Self::PathBufWrapper; + /// Returns the hashed password + fn pwd(&self) -> &[u8]; + /// Returns the file version + fn version(&self) -> DatFileVersion; + + // Automatically implemented methods ======================================= + + /// Returns the path of the wallet file + fn wallet_path(&self) -> &PathBuf { + self.path().wallet_path() + } + + /// Returns the directory of the profile + fn profile_dir(&self) -> &PathBuf { + self.path().profile_dir() + } + + /// Returns the network name for different cache locations + /// e.g, devnet, testnet, etc. + fn network(&self) -> Option<&String> { + self.path().network() + } + + /// Sets the network name for different cache locations. + /// e.g, devnet, testnet, etc. + fn set_network_name(&mut self, network: Option) { + self.path_mut().set_network_name(network); + } + + /// Returns the filename of this file + fn name(&self) -> Option { + self.path().name() + } + + /// Returns dir for cache based on network specified + fn cache_dir(&self) -> PathBuf { + self.path().cache_dir() + } + + /// Checks if the file version is older than the latest Rust Binary file + /// format + fn is_old(&self) -> bool { + let version = self.version(); + matches!( + version, + DatFileVersion::Legacy | DatFileVersion::OldWalletCli(_) + ) + } + + /// Get the seed and address from the file + fn get_seed_and_address(&self) -> Result<(Seed, u8), Error> { + let file_version = self.version(); + let pwd = self.pwd(); + let wallet_path = self.wallet_path().clone(); + + // Make sure the wallet file exists + if !wallet_path.is_file() { + return Err(Error::WalletFileMissing); + } + + // Load the wallet file + let mut bytes = std::fs::read(wallet_path)?; + + match file_version { + DatFileVersion::Legacy => { + if bytes[1] == 0 && bytes[2] == 0 { + bytes.drain(..3); + } + + bytes = decrypt(&bytes, pwd)?; + + let seed = bytes[..] + .try_into() + .map_err(|_| Error::WalletFileCorrupted)?; + + Ok((seed, 1)) + } + DatFileVersion::OldWalletCli((major, minor, _, _, _)) => { + bytes.drain(..5); + + let content = decrypt(&bytes, pwd)?; + let buff = &content[..]; + + let seed = + buff.try_into().map_err(|_| Error::WalletFileCorrupted)?; + + match (major, minor) { + (1, 0) => Ok((seed, 1)), + (2, 0) => Ok((seed, buff[0])), + _ => Err(Error::UnknownFileVersion(major, minor)), + } + } + DatFileVersion::RuskBinaryFileFormat(_) => { + let rest = bytes.get(12..(12 + 96)); + + if let Some(rest) = rest { + let content = decrypt(rest, pwd)?; + + if let Some(seed_buf) = content.get(0..65) { + let seed = seed_buf[0..64] + .try_into() + .map_err(|_| Error::WalletFileCorrupted)?; + + let addr_count = &seed_buf[64..65]; + + Ok((seed, addr_count[0])) + } else { + Err(Error::WalletFileCorrupted) + } + } else { + Err(Error::WalletFileCorrupted) + } + } + } + } +} + +/// Provides access to the wallet file path, profile directory and network name, +/// and implements by default other useful methods +pub trait WalletFilePath { + // Methods to implement =================================================== + + /// Returns the path of the wallet file + fn wallet_path(&self) -> &PathBuf; + /// Returns the mutable path of the wallet file + fn wallet_path_mut(&mut self) -> &mut PathBuf; + /// Returns the directory of the profile + fn profile_dir(&self) -> &PathBuf; + /// Returns the network name for different cache locations + /// e.g, devnet, testnet, etc. + fn network(&self) -> Option<&String>; + /// Returns the mutable network name + fn network_mut(&mut self) -> &mut Option; + + // Automatically implemented methods ======================================= + + /// Sets the network name for different cache locations. + /// e.g, devnet, testnet, etc. + fn set_network_name(&mut self, network: Option) { + *self.network_mut() = network; + } + + /// Returns the filename of this path + fn name(&self) -> Option { + // extract the name + let name = self.wallet_path().file_stem()?.to_str()?; + Some(String::from(name)) + } + + /// Returns dir for cache based on network specified + fn cache_dir(&self) -> PathBuf { + let mut cache = self.profile_dir().clone(); + + if let Some(network) = self.network() { + cache.push(format!("cache_{network}")); + } else { + cache.push("cache"); + } + + cache + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(PartialEq, Eq, Hash, Debug, Clone)] + struct MockWalletFilePath { + pub wallet_path: PathBuf, + pub profile_dir: PathBuf, + pub network: Option, + } + + impl WalletFilePath for MockWalletFilePath { + fn wallet_path(&self) -> &PathBuf { + &self.wallet_path + } + + fn wallet_path_mut(&mut self) -> &mut PathBuf { + &mut self.wallet_path + } + + fn profile_dir(&self) -> &PathBuf { + &self.profile_dir + } + + fn network(&self) -> Option<&String> { + self.network.as_ref() + } + + fn network_mut(&mut self) -> &mut Option { + &mut self.network + } + } + + #[derive(Debug, Clone)] + struct MockSecureWalletFile { + pub path: MockWalletFilePath, + pub pwd: Vec, + pub version: DatFileVersion, + } + + impl SecureWalletFile for MockSecureWalletFile { + type PathBufWrapper = MockWalletFilePath; + + fn path(&self) -> &Self::PathBufWrapper { + &self.path + } + + fn path_mut(&mut self) -> &mut Self::PathBufWrapper { + &mut self.path + } + + fn pwd(&self) -> &[u8] { + &self.pwd + } + + fn version(&self) -> DatFileVersion { + self.version.clone() + } + } + + #[test] + fn test_secure_wallet_file_trait_methods() -> Result<(), Error> { + let file_path = PathBuf::from("wallet.dat"); + let profile_dir = PathBuf::from("profile"); + let network = Some("devnet".to_string()); + + let pwd = vec![1, 2, 3, 4]; + let version = DatFileVersion::RuskBinaryFileFormat((1, 0, 0, 0, false)); + + let wallet_path = MockWalletFilePath { + wallet_path: file_path.clone(), + profile_dir: profile_dir.clone(), + network: network.clone(), + }; + + let mut wallet_file = MockSecureWalletFile { + path: wallet_path.clone(), + pwd: pwd.clone(), + version: version.clone(), + }; + + assert_eq!( + wallet_file.wallet_path(), + wallet_path.wallet_path(), + "wallet path is not correct for SecureWalletFile" + ); + assert_eq!( + wallet_file.profile_dir(), + &profile_dir, + "profile dir is not correct for SecureWalletFile" + ); + assert_eq!( + wallet_file.network(), + network.as_ref(), + "network is not correct for SecureWalletFile" + ); + + let network = Some("testnet".to_string()); + + wallet_file.set_network_name(network.clone()); + + assert_eq!( + wallet_file.network(), + network.as_ref(), + "network is not correct for SecureWalletFile after set_network_name" + ); + + assert_eq!( + wallet_file.name(), + Some("wallet".to_string()), + "name is not correct for SecureWalletFile" + ); + + assert_eq!( + wallet_file.cache_dir(), + PathBuf::from("profile/cache_testnet"), + "cache_dir is not correct for SecureWalletFile" + ); + + assert!( + !wallet_file.is_old(), + "is_old is not correct for SecureWalletFile" + ); + + let old_file = MockSecureWalletFile { + path: wallet_path.clone(), + pwd: pwd.clone(), + version: DatFileVersion::Legacy, + }; + + assert!( + old_file.is_old(), + "is_old is not correct for SecureWalletFile with old file" + ); + + let another_old_file = MockSecureWalletFile { + path: wallet_path.clone(), + pwd: pwd.clone(), + version: DatFileVersion::OldWalletCli((1, 0, 0, 0, false)), + }; + + assert!( + another_old_file.is_old(), + "is_old is not correct for SecureWalletFile with another old file" + ); + + // TODO: test get_seed_and_address + + Ok(()) + } +} diff --git a/rusk-wallet/src/wallet/transaction.rs b/rusk-wallet/src/wallet/transaction.rs index ae64f6951b..638e168f27 100644 --- a/rusk-wallet/src/wallet/transaction.rs +++ b/rusk-wallet/src/wallet/transaction.rs @@ -19,12 +19,12 @@ use wallet_core::transaction::{ }; use zeroize::Zeroize; -use super::file::SecureWalletFile; use super::Wallet; use crate::clients::Prover; use crate::currency::Dusk; use crate::gas::Gas; use crate::Error; +use crate::SecureWalletFile; impl Wallet { /// Transfers funds between shielded addresses.