diff --git a/src/db_reference.rs b/src/db_reference.rs new file mode 100644 index 0000000..7607c0c --- /dev/null +++ b/src/db_reference.rs @@ -0,0 +1,39 @@ +use rocksdb::DB; +use std::cell::RefCell; +use std::sync::Arc; + +/// The type of a reference to a [rocksdb::DB] that is passed around the library. +pub(crate) type DbReference = Arc>; + +/// A wrapper around [DbReference] that cancels all background work when dropped. +/// +/// All users of [rocksdb::DB] should use this wrapper instead to avoid keeping background threads +/// alive after the database is dropped. +#[derive(Clone)] +pub(crate) struct DbReferenceHolder { + inner: Option, +} + +impl DbReferenceHolder { + pub fn new(db: DB) -> Self { + Self { + inner: Some(Arc::new(RefCell::new(db))), + } + } + + pub fn get(&self) -> Option<&DbReference> { + self.inner.as_ref() + } + + pub fn close(&mut self) { + if let Some(db) = self.inner.take().and_then(Arc::into_inner) { + db.borrow_mut().cancel_all_background_work(true); + } + } +} + +impl Drop for DbReferenceHolder { + fn drop(&mut self) { + self.close(); + } +} diff --git a/src/iter.rs b/src/iter.rs index 0cbd42f..5da6823 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -1,12 +1,13 @@ +use crate::db_reference::DbReferenceHolder; use crate::encoder::{decode_value, encode_key}; +use crate::exceptions::DbClosedError; use crate::util::error_message; use crate::{ReadOpt, ReadOptionsPy}; use core::slice; use libc::{c_char, c_uchar, size_t}; use pyo3::exceptions::PyException; use pyo3::prelude::*; -use rocksdb::{AsColumnFamilyRef, ColumnFamily, DB}; -use std::cell::RefCell; +use rocksdb::{AsColumnFamilyRef, ColumnFamily}; use std::ptr::null_mut; use std::sync::Arc; @@ -14,7 +15,7 @@ use std::sync::Arc; #[allow(dead_code)] pub(crate) struct RdictIter { /// iterator must keep a reference count of DB to keep DB alive. - pub(crate) db: Arc>, + pub(crate) db: DbReferenceHolder, pub(crate) inner: *mut librocksdb_sys::rocksdb_iterator_t, @@ -49,7 +50,7 @@ pub(crate) struct RdictValues { impl RdictIter { pub(crate) fn new( - db: &Arc>, + db: &DbReferenceHolder, cf: &Option>, readopts: ReadOptionsPy, pickle_loads: &PyObject, @@ -57,18 +58,21 @@ impl RdictIter { py: Python, ) -> PyResult { let readopts = readopts.to_read_opt(raw_mode, py)?; + + let db_inner = db + .get() + .ok_or_else(|| DbClosedError::new_err("DB instance already closed"))? + .borrow() + .inner(); + Ok(RdictIter { db: db.clone(), inner: unsafe { match cf { - None => { - librocksdb_sys::rocksdb_create_iterator(db.borrow().inner(), readopts.0) + None => librocksdb_sys::rocksdb_create_iterator(db_inner, readopts.0), + Some(cf) => { + librocksdb_sys::rocksdb_create_iterator_cf(db_inner, readopts.0, cf.inner()) } - Some(cf) => librocksdb_sys::rocksdb_create_iterator_cf( - db.borrow().inner(), - readopts.0, - cf.inner(), - ), } }, readopts, diff --git a/src/lib.rs b/src/lib.rs index 747212c..4b53fe5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ // #![feature(core_intrinsics)] +mod db_reference; mod encoder; mod exceptions; mod iter; diff --git a/src/rdict.rs b/src/rdict.rs index 1891f80..9804e3a 100644 --- a/src/rdict.rs +++ b/src/rdict.rs @@ -1,3 +1,4 @@ +use crate::db_reference::{DbReference, DbReferenceHolder}; use crate::encoder::{decode_value, encode_key, encode_value}; use crate::exceptions::DbClosedError; use crate::iter::{RdictItems, RdictKeys, RdictValues}; @@ -71,7 +72,7 @@ pub(crate) struct Rdict { pub(crate) access_type: AccessType, pub(crate) slice_transforms: Arc>>, // drop DB last - pub(crate) db: Option>>, + pub(crate) db: DbReferenceHolder, } /// Define DB Access Types. @@ -136,9 +137,9 @@ impl Rdict { .save(config_path) } - fn get_db(&self) -> PyResult<&Arc>> { + fn get_db(&self) -> PyResult<&DbReference> { self.db - .as_ref() + .get() .ok_or_else(|| DbClosedError::new_err("DB instance already closed")) } } @@ -266,7 +267,7 @@ impl Rdict { // save rocksdict config rocksdict_config.save(config_path)?; Ok(Rdict { - db: Some(Arc::new(RefCell::new(db))), + db: DbReferenceHolder::new(db), write_opt: (&w_opt).into(), flush_opt: FlushOptionsPy::new(), read_opt: r_opt.to_read_options(options.raw_mode, py)?, @@ -640,7 +641,7 @@ impl Rdict { }; RdictIter::new( - self.get_db()?, + &self.db, &self.column_family, read_opt, &self.loads, @@ -818,7 +819,7 @@ impl Rdict { "column name `{name}` does not exist, use `create_cf` to creat it", ))), Some(cf) => Ok(Self { - db: Some(db.clone()), + db: self.db.clone(), write_opt: (&self.write_opt_py).into(), flush_opt: self.flush_opt, read_opt: self.read_opt_py.to_read_options(self.opt_py.raw_mode, py)?, @@ -855,7 +856,10 @@ impl Rdict { None => Err(PyException::new_err(format!( "column name `{name}` does not exist, use `create_cf` to creat it", ))), - Some(cf) => Ok(ColumnFamilyPy { cf, db: db.clone() }), + Some(cf) => Ok(ColumnFamilyPy { + cf, + db: self.db.clone(), + }), } } @@ -1014,8 +1018,8 @@ impl Rdict { /// Other Column Family `Rdict` instances, `ColumnFamily` /// (cf handle) instances, iterator instances such as`RdictIter`, /// `RdictItems`, `RdictKeys`, `RdictValues` can all keep RocksDB - /// alive. `del` all associated instances mentioned above - /// to actually shut down RocksDB. + /// alive. `del` or `close` all associated instances mentioned + /// above to actually shut down RocksDB. /// fn close(&mut self) -> PyResult<()> { let f_opt = &self.flush_opt; @@ -1024,7 +1028,7 @@ impl Rdict { AccessTypeInner::ReadOnly { .. } | AccessTypeInner::Secondary { .. } => { drop(db); drop(self.column_family.take()); - drop(self.db.take()); + self.db.close(); return Ok(()); } _ => (), @@ -1037,7 +1041,7 @@ impl Rdict { }; drop(db); drop(self.column_family.take()); - drop(self.db.take()); + self.db.close(); match (flush_result, flush_wal_result) { (Ok(_), Ok(_)) => Ok(()), (Err(e), Ok(_)) => Err(PyException::new_err(e.to_string())), @@ -1257,7 +1261,7 @@ fn get_batch_inner<'a>( impl Drop for Rdict { // flush fn drop(&mut self) { - if let Some(db) = &self.db { + if let Some(db) = self.db.get() { let f_opt = &self.flush_opt; let db = db.borrow(); let _ = if let Some(cf) = &self.column_family { @@ -1269,7 +1273,7 @@ impl Drop for Rdict { // important, always drop column families first // to ensure that CF handles have shorter life than DB. drop(self.column_family.take()); - drop(self.db.take()); + self.db.close(); } } @@ -1283,7 +1287,7 @@ pub(crate) struct ColumnFamilyPy { // must follow this drop order pub(crate) cf: Arc, // must keep db alive - db: Arc>, + db: DbReferenceHolder, } unsafe impl Send for ColumnFamilyPy {} diff --git a/src/snapshot.rs b/src/snapshot.rs index 6c22205..237bb13 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -1,9 +1,10 @@ +use crate::db_reference::{DbReference, DbReferenceHolder}; use crate::encoder::{decode_value, encode_key}; +use crate::exceptions::DbClosedError; use crate::{Rdict, RdictItems, RdictIter, RdictKeys, RdictValues, ReadOptionsPy}; use pyo3::exceptions::PyException; use pyo3::prelude::*; -use rocksdb::{ColumnFamily, ReadOptions, DB}; -use std::cell::RefCell; +use rocksdb::{ColumnFamily, ReadOptions}; use std::ops::Deref; use std::sync::Arc; @@ -43,7 +44,7 @@ pub struct Snapshot { pub(crate) pickle_loads: PyObject, pub(crate) read_opt: ReadOptions, // decrease db Rc last - pub(crate) db: Arc>, + pub(crate) db: DbReferenceHolder, pub(crate) raw_mode: bool, } @@ -133,7 +134,7 @@ impl Snapshot { /// read from snapshot fn __getitem__(&self, key: &PyAny, py: Python) -> PyResult { - let db = self.db.borrow(); + let db = self.get_db().borrow(); let key = encode_key(key, self.raw_mode)?; let value_result = if let Some(cf) = &self.column_family { db.get_pinned_cf_opt(cf.deref(), &key[..], &self.read_opt) @@ -152,33 +153,40 @@ impl Snapshot { impl Snapshot { pub(crate) fn new(rdict: &Rdict, py: Python) -> PyResult { - if let Some(db) = &rdict.db { - let db_borrow = db.borrow(); - let snapshot = unsafe { librocksdb_sys::rocksdb_create_snapshot(db_borrow.inner()) }; - let r_opt: ReadOptions = rdict - .read_opt_py - .to_read_options(rdict.opt_py.raw_mode, py)?; - unsafe { - set_snapshot(r_opt.inner(), snapshot); - } - Ok(Snapshot { - inner: snapshot, - column_family: rdict.column_family.clone(), - pickle_loads: rdict.loads.clone(), - read_opt: r_opt, - db: db.clone(), - raw_mode: rdict.opt_py.raw_mode, - }) - } else { - Err(PyException::new_err("DB already closed")) + let db_inner = rdict + .db + .get() + .ok_or_else(|| DbClosedError::new_err("DB instance already closed"))? + .borrow() + .inner(); + let snapshot = unsafe { librocksdb_sys::rocksdb_create_snapshot(db_inner) }; + let r_opt: ReadOptions = rdict + .read_opt_py + .to_read_options(rdict.opt_py.raw_mode, py)?; + unsafe { + set_snapshot(r_opt.inner(), snapshot); } + Ok(Snapshot { + inner: snapshot, + column_family: rdict.column_family.clone(), + pickle_loads: rdict.loads.clone(), + read_opt: r_opt, + db: rdict.db.clone(), + raw_mode: rdict.opt_py.raw_mode, + }) + } + + fn get_db(&self) -> &DbReference { + self.db + .get() + .expect("Snapshot should never close its DbReference") } } impl Drop for Snapshot { fn drop(&mut self) { unsafe { - librocksdb_sys::rocksdb_release_snapshot(self.db.borrow().inner(), self.inner); + librocksdb_sys::rocksdb_release_snapshot(self.get_db().borrow().inner(), self.inner); } } }