diff --git a/Cargo.lock b/Cargo.lock index 01b03a5..826e338 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,6 +102,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -1169,6 +1180,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" name = "sqlite-studio" version = "0.1.5" dependencies = [ + "async-trait", "chrono", "clap", "color-eyre", diff --git a/Cargo.toml b/Cargo.toml index b9f7489..c69a8da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } warp = "0.3.7" open = "3.2.0" +async-trait = "0.1.80" [profile.release] strip = true diff --git a/src/main.rs b/src/main.rs index c2734f9..4c4fcd9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,7 @@ -use std::{collections::HashMap, path::Path, sync::Arc}; - +use async_trait::async_trait; use clap::Parser; -use color_eyre::eyre::OptionExt; -use tokio_rusqlite::{Connection, OpenFlags}; use warp::Filter; + const ROWS_PER_PAGE: i32 = 50; const SAMPLE_DB: &[u8] = include_bytes!("../sample.sqlite3"); @@ -33,9 +31,9 @@ async fn main() -> color_eyre::Result<()> { let args = Args::parse(); let db = if args.database == "preview" { tokio::fs::write("sample.db", SAMPLE_DB).await?; - TheDB::open("sample.db".to_string()).await? + sqlite::Db::open("sample.db".to_string()).await? } else { - TheDB::open(args.database).await? + sqlite::Db::open(args.database).await? }; let cors = warp::cors() @@ -121,302 +119,330 @@ mod statics { } } -#[derive(Clone)] -struct TheDB { - path: String, - conn: Arc, +#[async_trait] +trait Database: Sized + Clone + Send { + async fn open(path: String) -> color_eyre::Result; + + async fn overview(&self) -> color_eyre::Result; + + async fn tables(&self) -> color_eyre::Result; + + async fn table(&self, name: String) -> color_eyre::Result; + + async fn table_data(&self, name: String, page: i32) + -> color_eyre::Result; + + async fn query(&self, query: String) -> color_eyre::Result; } -impl TheDB { - async fn open(path: String) -> color_eyre::Result { - let conn = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).await?; +mod sqlite { + use async_trait::async_trait; + use color_eyre::eyre::OptionExt; + use std::{collections::HashMap, path::Path, sync::Arc}; + use tokio_rusqlite::{Connection, OpenFlags}; - // This is meant to test if the file at path is actually a DB. - let tables = conn - .call(|conn| { - Ok(conn.query_row( - r#" + use crate::{helpers, responses, Database, ROWS_PER_PAGE}; + + #[derive(Clone)] + pub struct Db { + path: String, + conn: Arc, + } + + #[async_trait] + impl Database for Db { + async fn open(path: String) -> color_eyre::Result { + let conn = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).await?; + + // This is meant to test if the file at path is actually a DB. + let tables = conn + .call(|conn| { + Ok(conn.query_row( + r#" SELECT count(*) FROM sqlite_master WHERE type="table" "#, - (), - |r| r.get::<_, i32>(0), - )?) - }) - .await?; + (), + |r| r.get::<_, i32>(0), + )?) + }) + .await?; - tracing::info!("found {tables} tables in {path}"); - Ok(Self { - path, - conn: Arc::new(conn), - }) - } + tracing::info!("found {tables} tables in {path}"); + Ok(Self { + path, + conn: Arc::new(conn), + }) + } - async fn overview(&self) -> color_eyre::Result { - let file_name = Path::new(&self.path) - .file_name() - .ok_or_eyre("failed to get file name overview")? - .to_str() - .ok_or_eyre("file name is not utf-8")? - .to_owned(); - - let metadata = tokio::fs::metadata(&self.path).await?; - - let sqlite_version = tokio_rusqlite::version().to_owned(); - let file_size = helpers::format_size(metadata.len() as f64); - let modified = metadata.modified()?.into(); - let created = metadata.created().ok().map(Into::into); - - let (tables, indexes, triggers, views, counts) = self - .conn - .call(move |conn| { - let tables = conn.query_row( - r#" + async fn overview(&self) -> color_eyre::Result { + let file_name = Path::new(&self.path) + .file_name() + .ok_or_eyre("failed to get file name overview")? + .to_str() + .ok_or_eyre("file name is not utf-8")? + .to_owned(); + + let metadata = tokio::fs::metadata(&self.path).await?; + + let sqlite_version = tokio_rusqlite::version().to_owned(); + let file_size = helpers::format_size(metadata.len() as f64); + let modified = metadata.modified()?.into(); + let created = metadata.created().ok().map(Into::into); + + let (tables, indexes, triggers, views, counts) = self + .conn + .call(move |conn| { + let tables = conn.query_row( + r#" SELECT count(*) FROM sqlite_master WHERE type="table" "#, - (), - |r| r.get::<_, i32>(0), - )?; + (), + |r| r.get::<_, i32>(0), + )?; - let indexes = conn.query_row( - r#" + let indexes = conn.query_row( + r#" SELECT count(*) FROM sqlite_master WHERE type="index" "#, - (), - |r| r.get::<_, i32>(0), - )?; + (), + |r| r.get::<_, i32>(0), + )?; - let triggers = conn.query_row( - r#" + let triggers = conn.query_row( + r#" SELECT count(*) FROM sqlite_master WHERE type="trigger" "#, - (), - |r| r.get::<_, i32>(0), - )?; + (), + |r| r.get::<_, i32>(0), + )?; - let views = conn.query_row( - r#" + let views = conn.query_row( + r#" SELECT count(*) FROM sqlite_master WHERE type="view" "#, - (), - |r| r.get::<_, i32>(0), - )?; - - let mut stmt = - conn.prepare(r#"SELECT name FROM sqlite_master WHERE type="table""#)?; - let table_names = stmt.query_map([], |row| row.get::<_, String>(0))?; - - let mut table_counts = HashMap::with_capacity(tables as usize); - for name in table_names { - let name = name?; - let count = - conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { - r.get::<_, i32>(0) - })?; + (), + |r| r.get::<_, i32>(0), + )?; - table_counts.insert(name, count); - } + let mut stmt = + conn.prepare(r#"SELECT name FROM sqlite_master WHERE type="table""#)?; + let table_names = stmt.query_map([], |row| row.get::<_, String>(0))?; - let mut counts = table_counts - .into_iter() - .map(|(name, count)| responses::RowCount { name, count }) - .collect::>(); + let mut table_counts = HashMap::with_capacity(tables as usize); + for name in table_names { + let name = name?; + let count = + conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { + r.get::<_, i32>(0) + })?; - counts.sort_by(|a, b| b.count.cmp(&a.count)); + table_counts.insert(name, count); + } - Ok((tables, indexes, triggers, views, counts)) - }) - .await?; - - Ok(responses::Overview { - file_name, - sqlite_version, - file_size, - created, - modified, - tables, - indexes, - triggers, - views, - counts, - }) - } + let mut counts = table_counts + .into_iter() + .map(|(name, count)| responses::RowCount { name, count }) + .collect::>(); - async fn tables(&self) -> color_eyre::Result { - let tables = self - .conn - .call(move |conn| { - let mut stmt = - conn.prepare(r#"SELECT name FROM sqlite_master WHERE type="table""#)?; - let table_names = stmt - .query_map([], |row| row.get::<_, String>(0))? - .collect::>(); - - let mut table_counts = HashMap::with_capacity(table_names.len()); - for name in table_names { - let name = name?; - let count = - conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { - r.get::<_, i32>(0) - })?; + counts.sort_by(|a, b| b.count.cmp(&a.count)); - table_counts.insert(name, count); - } + Ok((tables, indexes, triggers, views, counts)) + }) + .await?; + + Ok(responses::Overview { + file_name, + sqlite_version, + file_size, + created, + modified, + tables, + indexes, + triggers, + views, + counts, + }) + } - let mut counts = table_counts - .into_iter() - .map(|(name, count)| responses::RowCount { name, count }) - .collect::>(); + async fn tables(&self) -> color_eyre::Result { + let tables = self + .conn + .call(move |conn| { + let mut stmt = + conn.prepare(r#"SELECT name FROM sqlite_master WHERE type="table""#)?; + let table_names = stmt + .query_map([], |row| row.get::<_, String>(0))? + .collect::>(); + + let mut table_counts = HashMap::with_capacity(table_names.len()); + for name in table_names { + let name = name?; + let count = + conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { + r.get::<_, i32>(0) + })?; + + table_counts.insert(name, count); + } + + let mut counts = table_counts + .into_iter() + .map(|(name, count)| responses::RowCount { name, count }) + .collect::>(); + + counts.sort_by_key(|r| r.count); + + Ok(counts) + }) + .await?; - counts.sort_by_key(|r| r.count); + Ok(responses::Tables { tables }) + } - Ok(counts) - }) - .await?; + async fn table(&self, name: String) -> color_eyre::Result { + let metadata = tokio::fs::metadata(&self.path).await?; + let more_than_five = metadata.len() > 5_000_000_000; + + Ok(self + .conn + .call(move |conn| { + let sql = conn.query_row( + r#" + SELECT sql FROM sqlite_master + WHERE type="table" AND name = ?1 + "#, + [&name], + |r| r.get::<_, String>(0), + )?; - Ok(responses::Tables { tables }) - } + let row_count = + conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { + r.get::<_, i32>(0) + })?; - async fn table(&self, name: String) -> color_eyre::Result { - let metadata = tokio::fs::metadata(&self.path).await?; - let more_than_five = metadata.len() > 5_000_000_000; - - Ok(self - .conn - .call(move |conn| { - let sql = conn.query_row( - r#" - SELECT sql FROM sqlite_master WHERE type="table" AND name = ?1 - "#, - [&name], - |r| r.get::<_, String>(0), - )?; - - let row_count = - conn.query_row(&format!("SELECT count(*) FROM '{name}'"), (), |r| { - r.get::<_, i32>(0) - })?; - - let table_size = if more_than_five { - "---".to_owned() - } else { - let table_size = conn.query_row( - "SELECT SUM(pgsize) FROM dbstat WHERE name = ?1", + let table_size = if more_than_five { + "---".to_owned() + } else { + let table_size = conn.query_row( + "SELECT SUM(pgsize) FROM dbstat WHERE name = ?1", + [&name], + |r| r.get::<_, i64>(0), + )?; + helpers::format_size(table_size as f64) + }; + + let index_count = conn.query_row( + "SELECT count(*) FROM sqlite_master WHERE type='index' AND tbl_name=?1", [&name], - |r| r.get::<_, i64>(0), + |r| r.get::<_, i32>(0), )?; - helpers::format_size(table_size as f64) - }; - - let index_count = conn.query_row( - "SELECT count(*) FROM sqlite_master WHERE type='index' AND tbl_name=?1", - [&name], - |r| r.get::<_, i32>(0), - )?; - - let has_primary_key = - conn.query_row(&format!("PRAGMA table_info('{name}')"), [], |r| { - r.get::<_, i32>(5) - })? == 1; - let index_count = if has_primary_key { - index_count + 1 - } else { - index_count - }; - - let mut columns = conn.prepare(&format!("PRAGMA table_info('{name}')"))?; - let column_count = columns.query_map((), |r| r.get::<_, String>(1))?.count() as i32; - - Ok(responses::Table { - name, - sql, - row_count, - table_size, - index_count, - column_count, + + let has_primary_key = + conn.query_row(&format!("PRAGMA table_info('{name}')"), [], |r| { + r.get::<_, i32>(5) + })? == 1; + let index_count = if has_primary_key { + index_count + 1 + } else { + index_count + }; + + let mut columns = conn.prepare(&format!("PRAGMA table_info('{name}')"))?; + let column_count = + columns.query_map((), |r| r.get::<_, String>(1))?.count() as i32; + + Ok(responses::Table { + name, + sql, + row_count, + table_size, + index_count, + column_count, + }) }) - }) - .await?) - } + .await?) + } - async fn table_data( - &self, - name: String, - page: i32, - ) -> color_eyre::Result { - Ok(self - .conn - .call(move |conn| { - let first_column = - conn.query_row(&format!("PRAGMA table_info('{name}')"), [], |r| { - r.get::<_, String>(1) - })?; - - let offset = (page - 1) * ROWS_PER_PAGE; - let mut stmt = conn.prepare(&format!( - r#" - SELECT * - FROM '{name}' - ORDER BY {first_column} - LIMIT {ROWS_PER_PAGE} - OFFSET {offset} - "# - ))?; - let columns = stmt - .column_names() - .into_iter() - .map(ToOwned::to_owned) - .collect::>(); - - let columns_len = columns.len(); - let rows = stmt - .query_map((), |r| { - let mut rows = Vec::with_capacity(columns_len); - for i in 0..columns_len { - let val = helpers::value_to_json(r.get_ref(i)?); - rows.push(val); - } - Ok(rows) - })? - .filter_map(|x| x.ok()) - .collect::>(); - - Ok(responses::TableData { columns, rows }) - }) - .await?) - } + async fn table_data( + &self, + name: String, + page: i32, + ) -> color_eyre::Result { + Ok(self + .conn + .call(move |conn| { + let first_column = + conn.query_row(&format!("PRAGMA table_info('{name}')"), [], |r| { + r.get::<_, String>(1) + })?; - async fn query(&self, query: String) -> color_eyre::Result { - Ok(self - .conn - .call(move |conn| { - let mut stmt = conn.prepare(&query)?; - let columns = stmt - .column_names() - .into_iter() - .map(ToOwned::to_owned) - .collect::>(); - - let columns_len = columns.len(); - let rows = stmt - .query_map((), |r| { - let mut rows = Vec::with_capacity(columns_len); - for i in 0..columns_len { - let val = helpers::value_to_json(r.get_ref(i)?); - rows.push(val); - } - Ok(rows) - })? - .filter_map(|x| x.ok()) - .collect::>(); - - Ok(responses::Query { columns, rows }) - }) - .await?) + let offset = (page - 1) * ROWS_PER_PAGE; + let mut stmt = conn.prepare(&format!( + r#" + SELECT * + FROM '{name}' + ORDER BY {first_column} + LIMIT {ROWS_PER_PAGE} + OFFSET {offset} + "# + ))?; + let columns = stmt + .column_names() + .into_iter() + .map(ToOwned::to_owned) + .collect::>(); + + let columns_len = columns.len(); + let rows = stmt + .query_map((), |r| { + let mut rows = Vec::with_capacity(columns_len); + for i in 0..columns_len { + let val = helpers::value_to_json(r.get_ref(i)?); + rows.push(val); + } + Ok(rows) + })? + .filter_map(|x| x.ok()) + .collect::>(); + + Ok(responses::TableData { columns, rows }) + }) + .await?) + } + + async fn query(&self, query: String) -> color_eyre::Result { + Ok(self + .conn + .call(move |conn| { + let mut stmt = conn.prepare(&query)?; + let columns = stmt + .column_names() + .into_iter() + .map(ToOwned::to_owned) + .collect::>(); + + let columns_len = columns.len(); + let rows = stmt + .query_map((), |r| { + let mut rows = Vec::with_capacity(columns_len); + for i in 0..columns_len { + let val = helpers::value_to_json(r.get_ref(i)?); + rows.push(val); + } + Ok(rows) + })? + .filter_map(|x| x.ok()) + .collect::>(); + + Ok(responses::Query { columns, rows }) + }) + .await?) + } } } @@ -502,7 +528,7 @@ mod handlers { use serde::Deserialize; use warp::Filter; - use crate::{rejections, TheDB}; + use crate::{rejections, Database}; fn with_state( state: &T, @@ -512,7 +538,7 @@ mod handlers { } pub fn routes( - db: TheDB, + db: impl Database, ) -> impl Filter + Clone { let overview = warp::path::end() .and(warp::get()) @@ -550,7 +576,7 @@ mod handlers { pub page: Option, } - async fn overview(db: TheDB) -> Result { + async fn overview(db: impl Database) -> Result { let overview = db.overview().await.map_err(|e| { tracing::error!("error while getting database overview: {e}"); warp::reject::custom(rejections::InternalServerError) @@ -558,7 +584,7 @@ mod handlers { Ok(warp::reply::json(&overview)) } - async fn tables(db: TheDB) -> Result { + async fn tables(db: impl Database) -> Result { let tables = db.tables().await.map_err(|e| { tracing::error!("error while getting tables: {e}"); warp::reject::custom(rejections::InternalServerError) @@ -566,7 +592,7 @@ mod handlers { Ok(warp::reply::json(&tables)) } - async fn table(db: TheDB, name: String) -> Result { + async fn table(db: impl Database, name: String) -> Result { let tables = db.table(name).await.map_err(|e| { tracing::error!("error while getting table: {e}"); warp::reject::custom(rejections::InternalServerError) @@ -575,7 +601,7 @@ mod handlers { } async fn table_data( - db: TheDB, + db: impl Database, name: String, data: PageQuery, ) -> Result { @@ -589,7 +615,10 @@ mod handlers { Ok(warp::reply::json(&data)) } - async fn query(db: TheDB, query: QueryBody) -> Result { + async fn query( + db: impl Database, + query: QueryBody, + ) -> Result { let tables = db .query(query.query) .await