diff --git a/.gitignore b/.gitignore index b13651d19..d7cc7b9fd 100644 --- a/.gitignore +++ b/.gitignore @@ -26,11 +26,9 @@ bld/ .idea xcuserdata/ -# Added by cargo -# -# already existing elements were commented out +# Databases +*.sqlite -#/target node_modules/ clients/python/env/ diff --git a/Cargo.lock b/Cargo.lock index a1c0dbe4d..c0b0e98a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -435,6 +447,7 @@ dependencies = [ "rand", "rand_chacha", "reqwest", + "rusqlite", "rustls-platform-verifier", "schemars", "security-framework", @@ -648,6 +661,7 @@ dependencies = [ "hmac", "rand", "reqwest", + "rusqlite", "schemars", "serde", "serde_json", @@ -1594,6 +1608,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -1872,6 +1898,18 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] [[package]] name = "heck" @@ -2267,6 +2305,16 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "line-wrap" version = "0.2.0" @@ -3256,6 +3304,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags 2.5.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", + "uuid", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -4474,6 +4537,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" @@ -4848,6 +4917,26 @@ dependencies = [ "url", ] +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.63", +] + [[package]] name = "zeroize" version = "1.7.0" diff --git a/crates/bitwarden-core/Cargo.toml b/crates/bitwarden-core/Cargo.toml index ba3ebbe68..9a6651960 100644 --- a/crates/bitwarden-core/Cargo.toml +++ b/crates/bitwarden-core/Cargo.toml @@ -50,6 +50,7 @@ reqwest = { version = ">=0.12.5, <0.13", features = [ "http2", "json", ], default-features = false } +rusqlite = ">=0.31.0, <0.32" schemars = { version = ">=0.8.9, <0.9", features = ["uuid1", "chrono"] } serde = { version = ">=1.0, <2.0", features = ["derive"] } serde_json = ">=1.0.96, <2.0" diff --git a/crates/bitwarden-core/src/client/client.rs b/crates/bitwarden-core/src/client/client.rs index 3ea4ae7e8..93adb1f66 100644 --- a/crates/bitwarden-core/src/client/client.rs +++ b/crates/bitwarden-core/src/client/client.rs @@ -1,13 +1,16 @@ -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use reqwest::header::{self, HeaderValue}; use super::internal::InternalClient; #[cfg(feature = "internal")] use crate::client::flags::Flags; -use crate::client::{ - client_settings::ClientSettings, - internal::{ApiConfigurations, Tokens}, +use crate::{ + client::{ + client_settings::ClientSettings, + internal::{ApiConfigurations, Tokens}, + }, + SqliteDatabase, }; /// The main struct to interact with the Bitwarden SDK. @@ -79,6 +82,7 @@ impl Client { })), external_client, encryption_settings: RwLock::new(None), + db: Arc::new(Mutex::new(SqliteDatabase::default().unwrap())), }, } } diff --git a/crates/bitwarden-core/src/client/internal.rs b/crates/bitwarden-core/src/client/internal.rs index 55122b9a1..4f2c74df5 100644 --- a/crates/bitwarden-core/src/client/internal.rs +++ b/crates/bitwarden-core/src/client/internal.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; #[cfg(any(feature = "internal", feature = "secrets"))] use bitwarden_crypto::SymmetricCryptoKey; @@ -17,7 +17,7 @@ use crate::error::Error; use crate::{ auth::renew::renew_token, error::{Result, VaultLocked}, - DeviceType, + DeviceType, SqliteDatabase, }; #[derive(Debug, Clone)] @@ -57,6 +57,8 @@ pub struct InternalClient { pub(crate) external_client: reqwest::Client, pub(super) encryption_settings: RwLock>>, + + pub db: Arc>, } impl InternalClient { diff --git a/crates/bitwarden-core/src/database/mod.rs b/crates/bitwarden-core/src/database/mod.rs new file mode 100644 index 000000000..352aebfa1 --- /dev/null +++ b/crates/bitwarden-core/src/database/mod.rs @@ -0,0 +1,32 @@ +mod sqlite; +use std::borrow::Cow; + +pub use sqlite::SqliteDatabase; +use thiserror::Error; + +use crate::MissingFieldError; + +#[derive(Debug, Error)] +pub enum DatabaseError { + #[error("Database lock")] + DatabaseLock, + + #[error("Failed to open connection to database")] + FailedToOpenConnection, + + #[error(transparent)] + Migrator(#[from] MigratorError), + + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), + #[error(transparent)] + MissingField(#[from] MissingFieldError), +} + +#[derive(Debug, Error)] +pub enum MigratorError { + #[error("Internal error: {0}")] + Internal(Cow<'static, str>), +} diff --git a/crates/bitwarden-core/src/database/sqlite.rs b/crates/bitwarden-core/src/database/sqlite.rs new file mode 100644 index 000000000..ca919b607 --- /dev/null +++ b/crates/bitwarden-core/src/database/sqlite.rs @@ -0,0 +1,210 @@ +use std::cmp::Ordering; + +use log::info; +use rusqlite::Connection; + +use super::{DatabaseError, MigratorError}; + +#[derive(Debug)] +pub struct SqliteDatabase { + pub conn: Connection, +} + +impl SqliteDatabase { + pub fn default() -> Result { + let conn = + Connection::open("test.sqlite").map_err(|_| DatabaseError::FailedToOpenConnection)?; + + Self::new_conn(conn) + } + + pub fn new_test() -> Self { + let conn = Connection::open_in_memory().expect("Failed to open sqlite connection"); + + Self::new_conn(conn).expect("Created test db") + } + + /// Create a new SqliteDatabase with the given connection. + /// + /// This will migrate the database to the latest version. + fn new_conn(conn: Connection) -> Result { + let migrator = Migrator::new(); + migrator + .migrate(&conn, None) + .map_err(DatabaseError::Migrator)?; + + Ok(SqliteDatabase { conn }) + } +} + +/// Database migrator +/// +/// The current database version is stored in the user_version PRAGMA. +/// It will iterate through all migrations and apply up migrations. +pub(crate) struct Migrator { + migrations: Vec, +} + +impl Migrator { + pub fn new() -> Self { + Self { + migrations: MIGRATIONS.to_vec(), + } + } + + pub fn migrate( + &self, + conn: &Connection, + target_version: Option, + ) -> Result<(), MigratorError> { + let current_version = user_version(conn) + .map_err(|_| MigratorError::Internal("Failed to get user_version".into()))? + as usize; + + let target_version = target_version.unwrap_or(MIGRATIONS.len()); + + let migrations = filter_migrations(&self.migrations, current_version, target_version); + + info!( + "Migrating database. Current version: {}, Target version: {}", + current_version, target_version + ); + + for migration in migrations { + info!("Applying migration: {}", migration.description); + + match current_version < target_version { + true => { + conn.execute_batch(migration.up) + .map_err(|_| MigratorError::Internal("Failed to apply migration".into()))?; + } + false => { + conn.execute_batch(migration.down) + .map_err(|_| MigratorError::Internal("Failed to apply migration".into()))?; + } + } + } + + set_user_version(conn, target_version as i32) + .map_err(|_| MigratorError::Internal("Failed to set user_version".into()))?; + + Ok(()) + } +} + +/// Filter the migrations to apply based on the current and target version +fn filter_migrations( + migrations: &[Migration], + current_version: usize, + target_version: usize, +) -> Vec<&Migration> { + match current_version.cmp(&target_version) { + Ordering::Less => migrations + .iter() + .skip(current_version) + .take(target_version - current_version) + .collect(), + Ordering::Greater => migrations + .iter() + .skip(target_version) + .take(current_version - target_version) + .rev() + .collect(), + Ordering::Equal => Vec::new(), + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +struct Migration { + /// A description of the migration, used for logging + description: &'static str, + /// The SQL to run when migrating up + up: &'static str, + /// The SQL to run when migrating down + down: &'static str, +} + +const MIGRATIONS: &[Migration] = &[Migration { + description: "Create ciphers table", + up: "CREATE TABLE IF NOT EXISTS ciphers ( + id TEXT PRIMARY KEY, + value TEXT NOT NULL + )", + down: "DROP TABLE ciphers", +}]; + +/// Get the user_version of the database +fn user_version(conn: &Connection) -> Result { + conn.query_row("PRAGMA user_version", [], |row| row.get(0)) + .map_err(|e| e.into()) +} + +/// Set the user_version of the database +fn set_user_version(conn: &Connection, version: i32) -> Result<(), DatabaseError> { + conn.pragma_update(None, "user_version", version)?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + + use super::*; + + const MIGRATIONS: &[Migration] = &[ + Migration { + description: "Create ciphers table", + up: "CREATE TABLE IF NOT EXISTS ciphers ( + id TEXT PRIMARY KEY, + value TEXT NOT NULL + )", + down: "DROP TABLE ciphers", + }, + Migration { + description: "Create folders table", + up: "CREATE TABLE IF NOT EXISTS folders ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL + )", + down: "DROP TABLE folders", + }, + Migration { + description: "Create collections table", + up: "CREATE TABLE IF NOT EXISTS collections ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL + )", + down: "DROP TABLE collections", + }, + ]; + + #[test] + fn test_filter_migrations() { + let result = filter_migrations(MIGRATIONS, 0, 3); + + assert_eq!(result[0].description, "Create ciphers table"); + assert_eq!(result[1].description, "Create folders table"); + assert_eq!(result[2].description, "Create collections table"); + } + + #[test] + fn test_filter_migrations_less() { + let result = filter_migrations(MIGRATIONS, 1, 2); + assert_eq!(result.len(), 1); + assert_eq!(result[0].description, "Create folders table"); + } + + #[test] + fn test_filter_migrations_greater() { + let result = filter_migrations(MIGRATIONS, 2, 0); + assert_eq!(result.len(), 2); + assert_eq!(result[0].description, "Create folders table"); + assert_eq!(result[1].description, "Create ciphers table"); + } + + #[test] + fn test_filter_migrations_equal() { + let result = filter_migrations(MIGRATIONS, 1, 1); + assert_eq!(result.len(), 0); + } +} diff --git a/crates/bitwarden-core/src/lib.rs b/crates/bitwarden-core/src/lib.rs index 409f0133b..721e85371 100644 --- a/crates/bitwarden-core/src/lib.rs +++ b/crates/bitwarden-core/src/lib.rs @@ -12,6 +12,8 @@ pub use error::Error; #[cfg(feature = "internal")] pub mod mobile; pub use error::{MissingFieldError, VaultLocked}; +mod database; +pub use database::{DatabaseError, SqliteDatabase}; #[cfg(feature = "internal")] pub mod platform; #[cfg(feature = "secrets")] diff --git a/crates/bitwarden-vault/Cargo.toml b/crates/bitwarden-vault/Cargo.toml index e875e4c06..ed0acd5f0 100644 --- a/crates/bitwarden-vault/Cargo.toml +++ b/crates/bitwarden-vault/Cargo.toml @@ -32,6 +32,7 @@ chrono = { version = ">=0.4.26, <0.5", features = [ rand = ">=0.8.5, <0.9" hmac = ">=0.12.1, <0.13" reqwest = { version = ">=0.12.5, <0.13", default-features = false } +rusqlite = { version = ">=0.31.0, <0.32", features = ["uuid"] } schemars = { version = ">=0.8.9, <0.9", features = ["uuid1", "chrono"] } serde = { version = ">=1.0, <2.0", features = ["derive"] } serde_json = ">=1.0.96, <2.0" diff --git a/crates/bitwarden-vault/src/cipher/mod.rs b/crates/bitwarden-vault/src/cipher/mod.rs index 67513d524..2d792bb87 100644 --- a/crates/bitwarden-vault/src/cipher/mod.rs +++ b/crates/bitwarden-vault/src/cipher/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod identity; pub(crate) mod linked_id; pub(crate) mod local_data; pub(crate) mod login; +pub(crate) mod repository; pub(crate) mod secure_note; pub use attachment::{ diff --git a/crates/bitwarden-vault/src/cipher/repository.rs b/crates/bitwarden-vault/src/cipher/repository.rs new file mode 100644 index 000000000..8b368e7f3 --- /dev/null +++ b/crates/bitwarden-vault/src/cipher/repository.rs @@ -0,0 +1,206 @@ +use std::sync::{Arc, Mutex}; + +use bitwarden_core::{require, DatabaseError, Error, SqliteDatabase}; +use rusqlite::params; +use uuid::Uuid; + +use super::Cipher; + +pub trait CipherRepository { + /// Save a cipher to the repository. + fn save(&self, cipher: &Cipher) -> Result<(), DatabaseError>; + + /// Replace all ciphers in the repository with the given ciphers. + /// + /// Typically used during a sync operation. + fn replace_all(&mut self, ciphers: &[Cipher]) -> Result<(), DatabaseError>; + + /// Delete a cipher by its ID. + fn delete_by_id(&self, id: Uuid) -> Result<(), DatabaseError>; + + /// Get all ciphers from the repository. + fn get_all(&self) -> Result, DatabaseError>; +} + +/// A row in the ciphers table. +struct CipherRow { + #[allow(dead_code)] + id: Uuid, + value: String, +} + +pub struct CipherSqliteRepository { + db: Arc>, +} + +impl CipherSqliteRepository { + pub fn new(db: Arc>) -> Self { + Self { db: db.clone() } + } +} + +impl CipherRepository for CipherSqliteRepository { + fn save(&self, cipher: &Cipher) -> Result<(), DatabaseError> { + let id = require!(cipher.id); + let serialized = serde_json::to_string(cipher)?; + + let guard = self.db.lock().map_err(|_| DatabaseError::DatabaseLock)?; + + let mut stmt = guard.conn.prepare( + " + INSERT INTO ciphers (id, value) + VALUES (?1, ?2) + ON CONFLICT(id) DO UPDATE SET + value = ?2 + ", + )?; + stmt.execute((&id, &serialized))?; + + Ok(()) + } + + fn replace_all(&mut self, ciphers: &[Cipher]) -> Result<(), DatabaseError> { + let mut guard = self.db.lock().map_err(|_| DatabaseError::DatabaseLock)?; + + let tx = guard.conn.transaction()?; + { + tx.execute("DELETE FROM ciphers", [])?; + + let mut stmt = tx.prepare( + " + INSERT INTO ciphers (id, value) + VALUES (?1, ?2) + ", + )?; + + for cipher in ciphers { + let id = require!(cipher.id); + let serialized = serde_json::to_string(&cipher)?; + + stmt.execute(params![id, serialized])?; + } + } + tx.commit()?; + + Ok(()) + } + + fn delete_by_id(&self, id: Uuid) -> Result<(), DatabaseError> { + let guard = self.db.lock().map_err(|_| DatabaseError::DatabaseLock)?; + + let mut stmt = guard.conn.prepare("DELETE FROM ciphers WHERE id = ?1")?; + stmt.execute(params![id])?; + + Ok(()) + } + + fn get_all(&self) -> Result, DatabaseError> { + let guard = self.db.lock().map_err(|_| DatabaseError::DatabaseLock)?; + + let mut stmt = guard.conn.prepare("SELECT id, value FROM ciphers")?; + let rows = stmt.query_map([], |row| { + Ok(CipherRow { + id: row.get(0)?, + value: row.get(1)?, + }) + })?; + + let ciphers: Vec = rows + .flatten() + .flat_map(|row| -> Result { + let cipher: Cipher = serde_json::from_str(&row.value)?; + Ok(cipher) + }) + .collect(); + + Ok(ciphers) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{CipherRepromptType, CipherType}; + + fn mock_cipher(id: Uuid) -> Cipher { + Cipher { + id: Some(id), + organization_id: None, + folder_id: None, + collection_ids: vec![], + key: None, + name: "2.pMS6/icTQABtulw52pq2lg==|XXbxKxDTh+mWiN1HjH2N1w==|Q6PkuT+KX/axrgN9ubD5Ajk2YNwxQkgs3WJM0S0wtG8=".parse().unwrap(), + notes: None, + r#type: CipherType::Login, + login: None, + identity: None, + card: None, + secure_note: None, + favorite: false, + reprompt: CipherRepromptType::None, + organization_use_totp: false, + edit: true, + view_password: true, + local_data: None, + attachments: None, + fields: None, + password_history: None, + creation_date: "2024-01-30T17:55:36.150Z".parse().unwrap(), + deleted_date: None, + revision_date: "2024-01-30T17:55:36.150Z".parse().unwrap(), + } + } + + #[test] + fn test_save_get_all() { + let repo = CipherSqliteRepository::new(Arc::new(Mutex::new(SqliteDatabase::new_test()))); + + let cipher = mock_cipher("d55d65d7-c161-40a4-94ca-b0d20184d91a".parse().unwrap()); + + repo.save(&cipher).unwrap(); + + let ciphers = repo.get_all().unwrap(); + + assert_eq!(ciphers.len(), 1); + assert_eq!(ciphers[0].id, cipher.id); + } + + #[test] + fn test_delete_by_id() { + let repo = CipherSqliteRepository::new(Arc::new(Mutex::new(SqliteDatabase::new_test()))); + + let cipher = mock_cipher("d55d65d7-c161-40a4-94ca-b0d20184d91a".parse().unwrap()); + repo.save(&cipher).unwrap(); + + let ciphers = repo.get_all().unwrap(); + assert_eq!(ciphers.len(), 1); + + repo.delete_by_id(cipher.id.unwrap()).unwrap(); + let ciphers = repo.get_all().unwrap(); + assert_eq!(ciphers.len(), 0); + } + + #[test] + fn test_replace_all() { + let mut repo = + CipherSqliteRepository::new(Arc::new(Mutex::new(SqliteDatabase::new_test()))); + + let old_cipher = mock_cipher("d55d65d7-c161-40a4-94ca-b0d20184d91a".parse().unwrap()); + + repo.save(&old_cipher).unwrap(); + + let ciphers = repo.get_all().unwrap(); + assert_eq!(ciphers.len(), 1); + assert_eq!(ciphers[0].id, old_cipher.id); + + let new_ciphers = vec![mock_cipher( + "d55d65d7-c161-40a4-94ca-b0d20184d91c".parse().unwrap(), + )]; + + repo.replace_all(new_ciphers.as_slice()).unwrap(); + + let ciphers = repo.get_all().unwrap(); + assert_eq!(ciphers.len(), 1); + assert_eq!(ciphers[0].id, new_ciphers[0].id); + } +} diff --git a/crates/bitwarden-vault/src/client_vault.rs b/crates/bitwarden-vault/src/client_vault.rs index 67f6792b9..b446b2e59 100644 --- a/crates/bitwarden-vault/src/client_vault.rs +++ b/crates/bitwarden-vault/src/client_vault.rs @@ -1,17 +1,23 @@ use bitwarden_core::Client; use crate::{ + repository::{CipherRepository, CipherSqliteRepository}, sync::{sync, SyncError}, SyncRequest, SyncResponse, }; pub struct ClientVault<'a> { pub(crate) client: &'a Client, + pub cipher_repository: Box, } impl<'a> ClientVault<'a> { pub fn new(client: &'a Client) -> Self { - Self { client } + let t = client.internal.db.clone(); + Self { + client, + cipher_repository: Box::new(CipherSqliteRepository::new(t)), + } } pub async fn sync(&self, input: &SyncRequest) -> Result { diff --git a/crates/bitwarden-vault/src/sync.rs b/crates/bitwarden-vault/src/sync.rs index 6b4845d35..4b3f99ac1 100644 --- a/crates/bitwarden-vault/src/sync.rs +++ b/crates/bitwarden-vault/src/sync.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use uuid::Uuid; -use crate::{Cipher, Collection, Folder, GlobalDomains, VaultParseError}; +use crate::{Cipher, ClientVaultExt, Collection, Folder, GlobalDomains, VaultParseError}; #[derive(Debug, Error)] pub enum SyncError { @@ -46,7 +46,18 @@ pub(crate) async fn sync(client: &Client, input: &SyncRequest) -> Result, + }, +} + +/// The main struct to interact with the Bitwarden SDK. +#[derive(Debug)] +pub struct Client { + token: Option, + pub(crate) refresh_token: Option, + pub(crate) token_expires_on: Option, + pub(crate) login_method: Option, + + #[cfg(feature = "internal")] + flags: Flags, + + /// Use Client::get_api_configurations() to access this. + /// It should only be used directly in renew_token + #[doc(hidden)] + pub(crate) __api_configurations: ApiConfigurations, + + encryption_settings: Option, + pub db: Arc>, +} + +impl Client { + pub fn new(settings_input: Option) -> Self { + let settings = settings_input.unwrap_or_default(); + + fn new_client_builder() -> reqwest::ClientBuilder { + #[allow(unused_mut)] + let mut client_builder = reqwest::Client::builder(); + + #[cfg(all(not(target_os = "android"), not(target_arch = "wasm32")))] + { + client_builder = + client_builder.use_preconfigured_tls(rustls_platform_verifier::tls_config()); + } + + client_builder + } + + let external_client = new_client_builder().build().expect("Build should not fail"); + + let mut headers = header::HeaderMap::new(); + headers.append( + "Device-Type", + HeaderValue::from_str(&(settings.device_type as u8).to_string()) + .expect("All numbers are valid ASCII"), + ); + let client_builder = new_client_builder().default_headers(headers); + + let client = client_builder.build().expect("Build should not fail"); + + let identity = bitwarden_api_identity::apis::configuration::Configuration { + base_path: settings.identity_url, + user_agent: Some(settings.user_agent.clone()), + client: client.clone(), + basic_auth: None, + oauth_access_token: None, + bearer_access_token: None, + api_key: None, + }; + + let api = bitwarden_api_api::apis::configuration::Configuration { + base_path: settings.api_url, + user_agent: Some(settings.user_agent), + client, + basic_auth: None, + oauth_access_token: None, + bearer_access_token: None, + api_key: None, + }; + + Self { + token: None, + refresh_token: None, + token_expires_on: None, + login_method: None, + #[cfg(feature = "internal")] + flags: Flags::default(), + __api_configurations: ApiConfigurations { + identity, + api, + external_client, + device_type: settings.device_type, + }, + encryption_settings: None, + db: Arc::new(Mutex::new(SqliteDatabase::default().unwrap())), + } + } + + #[cfg(feature = "internal")] + pub fn load_flags(&mut self, flags: std::collections::HashMap) { + self.flags = Flags::load_from_map(flags); + } + + #[cfg(feature = "internal")] + pub(crate) fn get_flags(&self) -> &Flags { + &self.flags + } + + pub(crate) async fn get_api_configurations(&mut self) -> &ApiConfigurations { + // At the moment we ignore the error result from the token renewal, if it fails, + // the token will end up expiring and the next operation is going to fail anyway. + self.auth().renew_token().await.ok(); + &self.__api_configurations + } + + #[cfg(feature = "internal")] + pub(crate) fn get_http_client(&self) -> &reqwest::Client { + &self.__api_configurations.external_client + } + + #[cfg(feature = "internal")] + pub(crate) fn get_login_method(&self) -> &Option { + &self.login_method + } + + pub fn get_access_token_organization(&self) -> Option { + match self.login_method { + Some(LoginMethod::ServiceAccount(ServiceAccountLoginMethod::AccessToken { + organization_id, + .. + })) => Some(organization_id), + _ => None, + } + } + + pub(crate) fn get_encryption_settings(&self) -> Result<&EncryptionSettings> { + self.encryption_settings.as_ref().ok_or(Error::VaultLocked) + } + + pub(crate) fn set_login_method(&mut self, login_method: LoginMethod) { + use log::debug; + + debug! {"setting login method: {:#?}", login_method} + self.login_method = Some(login_method); + } + + pub(crate) fn set_tokens( + &mut self, + token: String, + refresh_token: Option, + expires_in: u64, + ) { + self.token = Some(token.clone()); + self.refresh_token = refresh_token; + self.token_expires_on = Some(Utc::now().timestamp() + expires_in as i64); + self.__api_configurations.identity.oauth_access_token = Some(token.clone()); + self.__api_configurations.api.oauth_access_token = Some(token); + } + + #[cfg(feature = "internal")] + pub fn is_authed(&self) -> bool { + self.token.is_some() || self.login_method.is_some() + } + + #[cfg(feature = "internal")] + pub(crate) fn initialize_user_crypto_master_key( + &mut self, + master_key: MasterKey, + user_key: EncString, + private_key: EncString, + ) -> Result<&EncryptionSettings> { + Ok(self.encryption_settings.insert(EncryptionSettings::new( + master_key, + user_key, + private_key, + )?)) + } + + #[cfg(feature = "internal")] + pub(crate) fn initialize_user_crypto_decrypted_key( + &mut self, + user_key: SymmetricCryptoKey, + private_key: EncString, + ) -> Result<&EncryptionSettings> { + Ok(self + .encryption_settings + .insert(EncryptionSettings::new_decrypted_key( + user_key, + private_key, + )?)) + } + + #[cfg(feature = "internal")] + pub(crate) fn initialize_user_crypto_pin( + &mut self, + pin_key: MasterKey, + pin_protected_user_key: EncString, + private_key: EncString, + ) -> Result<&EncryptionSettings> { + let decrypted_user_key = pin_key.decrypt_user_key(pin_protected_user_key)?; + self.initialize_user_crypto_decrypted_key(decrypted_user_key, private_key) + } + + pub(crate) fn initialize_crypto_single_key( + &mut self, + key: SymmetricCryptoKey, + ) -> &EncryptionSettings { + self.encryption_settings + .insert(EncryptionSettings::new_single_key(key)) + } + + #[cfg(feature = "internal")] + pub(crate) fn initialize_org_crypto( + &mut self, + org_keys: Vec<(Uuid, AsymmetricEncString)>, + ) -> Result<&EncryptionSettings> { + let enc = self + .encryption_settings + .as_mut() + .ok_or(Error::VaultLocked)?; + + enc.set_org_keys(org_keys)?; + Ok(&*enc) + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_reqwest_rustls_platform_verifier_are_compatible() { + // rustls-platform-verifier is generating a rustls::ClientConfig, + // which reqwest accepts as a &dyn Any and then downcasts it to a + // rustls::ClientConfig. + + // This means that if the rustls version of the two crates don't match, + // the downcast will fail and we will get a runtime error. + + // This tests is added to ensure that it doesn't happen. + + let _ = reqwest::ClientBuilder::new() + .use_preconfigured_tls(rustls_platform_verifier::tls_config()) + .build() + .unwrap(); + } +} diff --git a/crates/bitwarden/src/client/mod.rs b/crates/bitwarden/src/client/mod.rs new file mode 100644 index 000000000..87a8315c2 --- /dev/null +++ b/crates/bitwarden/src/client/mod.rs @@ -0,0 +1,15 @@ +//! Bitwarden SDK Client + +pub(crate) use client::*; +#[allow(clippy::module_inception)] +mod client; +pub mod client_settings; +pub(crate) mod database; +pub(crate) use database::SqliteDatabase; +pub(crate) mod encryption_settings; + +#[cfg(feature = "internal")] +mod flags; + +pub use client::Client; +pub use client_settings::{ClientSettings, DeviceType}; diff --git a/crates/bitwarden/src/vault/client_vault.rs b/crates/bitwarden/src/vault/client_vault.rs new file mode 100644 index 000000000..b53ad7644 --- /dev/null +++ b/crates/bitwarden/src/vault/client_vault.rs @@ -0,0 +1,34 @@ +use super::{ + repository::CipherSqliteRepository, + sync::{sync, SyncRequest, SyncResponse}, +}; +use crate::{error::Result, vault::cipher::repository::CipherRepository, Client}; + +pub struct ClientVault<'a> { + pub(crate) client: &'a mut crate::Client, + pub cipher_repository: Box, +} + +impl<'a> ClientVault<'a> { + pub async fn sync(&mut self, input: &SyncRequest) -> Result { + sync(self.client, input).await + } +} + +impl<'a> Client { + pub fn vault(&'a mut self) -> ClientVault<'a> { + let t = self.db.clone(); + ClientVault { + client: self, + cipher_repository: Box::new(CipherSqliteRepository::new(t)), + } + } +} + +pub struct ClientRepositories {} + +impl std::fmt::Debug for ClientRepositories { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClientRepositories").finish() + } +}