diff --git a/Cargo.toml b/Cargo.toml index b60cca31..a1e7e0b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ modern-full = [ "chrono", "serde_json", "url", + "r2d2", ] [dependencies] @@ -48,6 +49,7 @@ cast = { version = "0.3", features = ["std"] } arrow = { version = "6.5.0", default-features = false, features = ["prettyprint"] } rust_decimal = "1.14" strum = { version = "0.23", features = ["derive"] } +r2d2 = { version = "0.8.9", optional = true } [dev-dependencies] doc-comment = "0.3" @@ -57,6 +59,7 @@ regex = "1.3" uuid = { version = "0.8", features = ["v4"] } unicase = "2.6.0" rand = "0.8.3" +tempdir = "0.3.7" # criterion = "0.3" # [[bench]] diff --git a/src/lib.rs b/src/lib.rs index 281b537f..c389a715 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,8 @@ pub use crate::config::{AccessMode, Config, DefaultNullOrder, DefaultOrder}; pub use crate::error::Error; pub use crate::ffi::ErrorCode; pub use crate::params::{params_from_iter, Params, ParamsFromIter}; +#[cfg(feature = "r2d2")] +pub use crate::r2d2::DuckdbConnectionManager; pub use crate::row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows}; pub use crate::statement::Statement; pub use crate::transaction::{DropBehavior, Savepoint, Transaction, TransactionBehavior}; @@ -93,6 +95,8 @@ mod config; mod inner_connection; mod params; mod pragma; +#[cfg(feature = "r2d2")] +mod r2d2; mod raw_statement; mod row; mod statement; diff --git a/src/r2d2.rs b/src/r2d2.rs new file mode 100644 index 00000000..0f747246 --- /dev/null +++ b/src/r2d2.rs @@ -0,0 +1,230 @@ +#![deny(warnings)] +//! # Duckdb-rs support for the `r2d2` connection pool. +//! +//! +//! Integrated with: [r2d2](https://crates.io/crates/r2d2) +//! +//! +//! ## Example +//! +//! ```rust,no_run +//! extern crate r2d2; +//! extern crate duckdb; +//! +//! +//! use std::thread; +//! use duckdb::{DuckdbConnectionManager, params}; +//! +//! +//! fn main() { +//! let manager = DuckdbConnectionManager::file("file.db").unwrap(); +//! let pool = r2d2::Pool::new(manager).unwrap(); +//! pool.get() +//! .unwrap() +//! .execute("CREATE TABLE IF NOT EXISTS foo (bar INTEGER)", params![]) +//! .unwrap(); +//! +//! (0..10) +//! .map(|i| { +//! let pool = pool.clone(); +//! thread::spawn(move || { +//! let conn = pool.get().unwrap(); +//! conn.execute("INSERT INTO foo (bar) VALUES (?)", &[&i]) +//! .unwrap(); +//! }) +//! }) +//! .collect::>() +//! .into_iter() +//! .map(thread::JoinHandle::join) +//! .collect::>() +//! .unwrap() +//! } +//! ``` +use crate::{Config, Connection, Error, Result}; +use std::{ + path::Path, + sync::{Arc, Mutex}, +}; + +/// An `r2d2::ManageConnection` for `duckdb::Connection`s. +pub struct DuckdbConnectionManager { + connection: Arc>, +} + +impl DuckdbConnectionManager { + /// Creates a new `DuckdbConnectionManager` from file. + pub fn file>(path: P) -> Result { + Ok(Self { + connection: Arc::new(Mutex::new(Connection::open(path)?)), + }) + } + /// Creates a new `DuckdbConnectionManager` from file with flags. + pub fn file_with_flags>(path: P, config: Config) -> Result { + Ok(Self { + connection: Arc::new(Mutex::new(Connection::open_with_flags(path, config)?)), + }) + } + + /// Creates a new `DuckdbConnectionManager` from memory. + pub fn memory() -> Result { + Ok(Self { + connection: Arc::new(Mutex::new(Connection::open_in_memory()?)), + }) + } + + /// Creates a new `DuckdbConnectionManager` from memory with flags. + pub fn memory_with_flags(config: Config) -> Result { + Ok(Self { + connection: Arc::new(Mutex::new(Connection::open_in_memory_with_flags(config)?)), + }) + } +} + +impl r2d2::ManageConnection for DuckdbConnectionManager { + type Connection = Connection; + type Error = Error; + + fn connect(&self) -> Result { + let conn = self.connection.lock().unwrap(); + Ok(conn.clone()) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + conn.execute_batch("").map_err(Into::into) + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} + +#[cfg(test)] +mod test { + extern crate r2d2; + use super::*; + use crate::types::Value; + use crate::Result; + use std::{sync::mpsc, thread}; + + use tempdir::TempDir; + + #[test] + fn test_basic() -> Result<()> { + let manager = DuckdbConnectionManager::file("file.db")?; + let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap(); + + let (s1, r1) = mpsc::channel(); + let (s2, r2) = mpsc::channel(); + + let pool1 = pool.clone(); + let t1 = thread::spawn(move || { + let conn = pool1.get().unwrap(); + s1.send(()).unwrap(); + r2.recv().unwrap(); + drop(conn); + }); + + let pool2 = pool.clone(); + let t2 = thread::spawn(move || { + let conn = pool2.get().unwrap(); + s2.send(()).unwrap(); + r1.recv().unwrap(); + drop(conn); + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + pool.get().unwrap(); + Ok(()) + } + + #[test] + fn test_file() -> Result<()> { + let manager = DuckdbConnectionManager::file("file.db")?; + let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap(); + + let (s1, r1) = mpsc::channel(); + let (s2, r2) = mpsc::channel(); + + let pool1 = pool.clone(); + let t1 = thread::spawn(move || { + let conn = pool1.get().unwrap(); + s1.send(()).unwrap(); + r2.recv().unwrap(); + drop(conn); + }); + + let pool2 = pool.clone(); + let t2 = thread::spawn(move || { + let conn = pool2.get().unwrap(); + s2.send(()).unwrap(); + r1.recv().unwrap(); + drop(conn); + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + pool.get().unwrap(); + Ok(()) + } + + #[test] + fn test_is_valid() -> Result<()> { + let manager = DuckdbConnectionManager::file("file.db")?; + let pool = r2d2::Pool::builder() + .max_size(1) + .test_on_check_out(true) + .build(manager) + .unwrap(); + + pool.get().unwrap(); + Ok(()) + } + + #[test] + fn test_error_handling() -> Result<()> { + //! We specify a directory as a database. This is bound to fail. + let dir = TempDir::new("r2d2-duckdb").expect("Could not create temporary directory"); + let dirpath = dir.path().to_str().unwrap(); + assert!(DuckdbConnectionManager::file(dirpath).is_err()); + Ok(()) + } + + #[test] + fn test_with_flags() -> Result<()> { + let config = Config::default() + .access_mode(crate::AccessMode::ReadWrite)? + .default_null_order(crate::DefaultNullOrder::NullsLast)? + .default_order(crate::DefaultOrder::Desc)? + .enable_external_access(true)? + .enable_object_cache(false)? + .max_memory("2GB")? + .threads(4)?; + let manager = DuckdbConnectionManager::file_with_flags("file.db", config)?; + let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap(); + let conn = pool.get().unwrap(); + conn.execute_batch("CREATE TABLE foo(x Text)")?; + + let mut stmt = conn.prepare("INSERT INTO foo(x) VALUES (?)")?; + stmt.execute(&[&"a"])?; + stmt.execute(&[&"b"])?; + stmt.execute(&[&"c"])?; + stmt.execute([Value::Null])?; + + let val: Result>> = conn + .prepare("SELECT x FROM foo ORDER BY x")? + .query_and_then([], |row| row.get(0))? + .collect(); + let val = val?; + let mut iter = val.iter(); + assert_eq!(iter.next().unwrap().as_ref().unwrap(), "c"); + assert_eq!(iter.next().unwrap().as_ref().unwrap(), "b"); + assert_eq!(iter.next().unwrap().as_ref().unwrap(), "a"); + assert!(iter.next().unwrap().is_none()); + assert_eq!(iter.next(), None); + + Ok(()) + } +}