
Cancel all background work on last reference drop (#98)
This cancels all background work before dropping the last reference, so that
background threads cannot keep the database open.
GodTamIt authored Dec 4, 2023
1 parent 4496dd8 commit c01968e
Showing 5 changed files with 105 additions and 49 deletions.
39 changes: 39 additions & 0 deletions src/db_reference.rs
@@ -0,0 +1,39 @@
use rocksdb::DB;
use std::cell::RefCell;
use std::sync::Arc;

/// The type of a reference to a [rocksdb::DB] that is passed around the library.
pub(crate) type DbReference = Arc<RefCell<DB>>;

/// A wrapper around [DbReference] that cancels all background work when dropped.
///
/// All users of [rocksdb::DB] should use this wrapper instead to avoid keeping background threads
/// alive after the database is dropped.
#[derive(Clone)]
pub(crate) struct DbReferenceHolder {
inner: Option<DbReference>,
}

impl DbReferenceHolder {
pub fn new(db: DB) -> Self {
Self {
inner: Some(Arc::new(RefCell::new(db))),
}
}

pub fn get(&self) -> Option<&DbReference> {
self.inner.as_ref()
}

pub fn close(&mut self) {
if let Some(db) = self.inner.take().and_then(Arc::into_inner) {
db.borrow_mut().cancel_all_background_work(true);
}
}
}

impl Drop for DbReferenceHolder {
fn drop(&mut self) {
self.close();
}
}
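
The holder above leans on Arc::into_inner, which yields the wrapped value only when the calling Arc is the last strong reference, so cancellation runs exactly once no matter how many clones exist. Below is a minimal, self-contained sketch of that behavior (FakeDb and its cancel_all_background_work are stand-ins for rocksdb::DB, not the library's real code):

use std::cell::RefCell;
use std::sync::Arc;

// Stand-in for rocksdb::DB.
struct FakeDb;

impl FakeDb {
    fn cancel_all_background_work(&mut self, _wait: bool) {
        println!("background work cancelled");
    }
}

#[derive(Clone)]
struct Holder {
    inner: Option<Arc<RefCell<FakeDb>>>,
}

impl Holder {
    fn close(&mut self) {
        // Arc::into_inner returns Some only for the last strong reference,
        // so the cancel call happens exactly once, on the final drop.
        if let Some(db) = self.inner.take().and_then(Arc::into_inner) {
            db.borrow_mut().cancel_all_background_work(true);
        }
    }
}

impl Drop for Holder {
    fn drop(&mut self) {
        self.close();
    }
}

fn main() {
    let rdict_ref = Holder {
        inner: Some(Arc::new(RefCell::new(FakeDb))),
    };
    let iter_ref = rdict_ref.clone();
    drop(rdict_ref); // a clone is still alive: nothing is cancelled
    drop(iter_ref); // last reference: prints "background work cancelled"
}

Because close() empties inner, calling it explicitly and then letting Drop run later is harmless; the second call is a no-op.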
26 changes: 15 additions & 11 deletions src/iter.rs
@@ -1,20 +1,21 @@
+use crate::db_reference::DbReferenceHolder;
use crate::encoder::{decode_value, encode_key};
+use crate::exceptions::DbClosedError;
use crate::util::error_message;
use crate::{ReadOpt, ReadOptionsPy};
use core::slice;
use libc::{c_char, c_uchar, size_t};
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
-use rocksdb::{AsColumnFamilyRef, ColumnFamily, DB};
-use std::cell::RefCell;
+use rocksdb::{AsColumnFamilyRef, ColumnFamily};
use std::ptr::null_mut;
use std::sync::Arc;

#[pyclass]
#[allow(dead_code)]
pub(crate) struct RdictIter {
/// iterator must keep a reference count of DB to keep DB alive.
-pub(crate) db: Arc<RefCell<DB>>,
+pub(crate) db: DbReferenceHolder,

pub(crate) inner: *mut librocksdb_sys::rocksdb_iterator_t,

@@ -49,26 +50,29 @@ pub(crate) struct RdictValues

impl RdictIter {
pub(crate) fn new(
-db: &Arc<RefCell<DB>>,
+db: &DbReferenceHolder,
cf: &Option<Arc<ColumnFamily>>,
readopts: ReadOptionsPy,
pickle_loads: &PyObject,
raw_mode: bool,
py: Python,
) -> PyResult<Self> {
let readopts = readopts.to_read_opt(raw_mode, py)?;

+let db_inner = db
+.get()
+.ok_or_else(|| DbClosedError::new_err("DB instance already closed"))?
+.borrow()
+.inner();

Ok(RdictIter {
db: db.clone(),
inner: unsafe {
match cf {
-None => {
-librocksdb_sys::rocksdb_create_iterator(db.borrow().inner(), readopts.0)
+None => librocksdb_sys::rocksdb_create_iterator(db_inner, readopts.0),
+Some(cf) => {
+librocksdb_sys::rocksdb_create_iterator_cf(db_inner, readopts.0, cf.inner())
}
-Some(cf) => librocksdb_sys::rocksdb_create_iterator_cf(
-db.borrow().inner(),
-readopts.0,
-cf.inner(),
-),
}
},
readopts,
1 change: 1 addition & 0 deletions src/lib.rs
@@ -1,4 +1,5 @@
// #![feature(core_intrinsics)]
mod db_reference;
mod encoder;
mod exceptions;
mod iter;
32 changes: 18 additions & 14 deletions src/rdict.rs
@@ -1,3 +1,4 @@
+use crate::db_reference::{DbReference, DbReferenceHolder};
use crate::encoder::{decode_value, encode_key, encode_value};
use crate::exceptions::DbClosedError;
use crate::iter::{RdictItems, RdictKeys, RdictValues};
@@ -71,7 +72,7 @@ pub(crate) struct Rdict {
pub(crate) access_type: AccessType,
pub(crate) slice_transforms: Arc<RwLock<HashMap<String, SliceTransformType>>>,
// drop DB last
-pub(crate) db: Option<Arc<RefCell<DB>>>,
+pub(crate) db: DbReferenceHolder,
}

/// Define DB Access Types.
@@ -136,9 +137,9 @@ impl Rdict {
.save(config_path)
}

-fn get_db(&self) -> PyResult<&Arc<RefCell<DB>>> {
+fn get_db(&self) -> PyResult<&DbReference> {
self.db
-.as_ref()
+.get()
.ok_or_else(|| DbClosedError::new_err("DB instance already closed"))
}
}
@@ -266,7 +267,7 @@ impl Rdict {
// save rocksdict config
rocksdict_config.save(config_path)?;
Ok(Rdict {
-db: Some(Arc::new(RefCell::new(db))),
+db: DbReferenceHolder::new(db),
write_opt: (&w_opt).into(),
flush_opt: FlushOptionsPy::new(),
read_opt: r_opt.to_read_options(options.raw_mode, py)?,
@@ -640,7 +641,7 @@ impl Rdict {
};

RdictIter::new(
-self.get_db()?,
+&self.db,
&self.column_family,
read_opt,
&self.loads,
@@ -818,7 +819,7 @@ impl Rdict {
"column name `{name}` does not exist, use `create_cf` to creat it",
))),
Some(cf) => Ok(Self {
-db: Some(db.clone()),
+db: self.db.clone(),
write_opt: (&self.write_opt_py).into(),
flush_opt: self.flush_opt,
read_opt: self.read_opt_py.to_read_options(self.opt_py.raw_mode, py)?,
@@ -855,7 +856,10 @@ impl Rdict {
None => Err(PyException::new_err(format!(
"column name `{name}` does not exist, use `create_cf` to creat it",
))),
-Some(cf) => Ok(ColumnFamilyPy { cf, db: db.clone() }),
+Some(cf) => Ok(ColumnFamilyPy {
+cf,
+db: self.db.clone(),
+}),
}
}

Expand Down Expand Up @@ -1014,8 +1018,8 @@ impl Rdict {
/// Other Column Family `Rdict` instances, `ColumnFamily`
/// (cf handle) instances, iterator instances such as `RdictIter`,
/// `RdictItems`, `RdictKeys`, `RdictValues` can all keep RocksDB
-/// alive. `del` all associated instances mentioned above
-/// to actually shut down RocksDB.
+/// alive. `del` or `close` all associated instances mentioned
+/// above to actually shut down RocksDB.
///
fn close(&mut self) -> PyResult<()> {
let f_opt = &self.flush_opt;
@@ -1024,7 +1028,7 @@ impl Rdict {
AccessTypeInner::ReadOnly { .. } | AccessTypeInner::Secondary { .. } => {
drop(db);
drop(self.column_family.take());
-drop(self.db.take());
+self.db.close();
return Ok(());
}
_ => (),
@@ -1037,7 +1041,7 @@
};
drop(db);
drop(self.column_family.take());
-drop(self.db.take());
+self.db.close();
match (flush_result, flush_wal_result) {
(Ok(_), Ok(_)) => Ok(()),
(Err(e), Ok(_)) => Err(PyException::new_err(e.to_string())),
@@ -1257,7 +1261,7 @@ fn get_batch_inner<'a>(
impl Drop for Rdict {
// flush
fn drop(&mut self) {
-if let Some(db) = &self.db {
+if let Some(db) = self.db.get() {
let f_opt = &self.flush_opt;
let db = db.borrow();
let _ = if let Some(cf) = &self.column_family {
@@ -1269,7 +1273,7 @@ impl Drop for Rdict {
// important, always drop column families first
// to ensure that CF handles have shorter life than DB.
drop(self.column_family.take());
-drop(self.db.take());
+self.db.close();
}
}

@@ -1283,7 +1287,7 @@ pub(crate) struct ColumnFamilyPy {
// must follow this drop order
pub(crate) cf: Arc<ColumnFamily>,
// must keep db alive
-db: Arc<RefCell<DB>>,
+db: DbReferenceHolder,
}

unsafe impl Send for ColumnFamilyPy {}
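
The close() and Drop paths in rdict.rs above follow a fixed ordering: flush, release column-family handles, then release the DB reference through the holder. Below is a simplified sketch of that ordering using stand-in types (not the real rocksdb or PyO3 structs); flushing and error handling are elided:

use std::sync::Arc;

// Stand-ins for the real handles; only the drop ordering is illustrated.
struct FakeCfHandle;

struct FakeDbHolder {
    inner: Option<Arc<()>>, // () stands in for the shared DB
}

impl FakeDbHolder {
    fn close(&mut self) {
        // Idempotent: after the first call `inner` is empty, so a later
        // Drop-driven call is a no-op (mirrors DbReferenceHolder::close).
        if self.inner.take().is_some() {
            println!("2. db reference released");
        }
    }
}

struct FakeRdict {
    column_family: Option<Arc<FakeCfHandle>>,
    db: FakeDbHolder,
}

impl Drop for FakeRdict {
    fn drop(&mut self) {
        // Column-family handles must not outlive the DB they point into,
        // so they are released before the DB reference.
        if self.column_family.take().is_some() {
            println!("1. column family handle dropped");
        }
        self.db.close();
    }
}

fn main() {
    let _rdict = FakeRdict {
        column_family: Some(Arc::new(FakeCfHandle)),
        db: FakeDbHolder { inner: Some(Arc::new(())) },
    };
    // Dropped at end of scope: prints step 1, then step 2.
}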
56 changes: 32 additions & 24 deletions src/snapshot.rs
@@ -1,9 +1,10 @@
+use crate::db_reference::{DbReference, DbReferenceHolder};
use crate::encoder::{decode_value, encode_key};
+use crate::exceptions::DbClosedError;
use crate::{Rdict, RdictItems, RdictIter, RdictKeys, RdictValues, ReadOptionsPy};
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
-use rocksdb::{ColumnFamily, ReadOptions, DB};
-use std::cell::RefCell;
+use rocksdb::{ColumnFamily, ReadOptions};
use std::ops::Deref;
use std::sync::Arc;

@@ -43,7 +44,7 @@ pub struct Snapshot {
pub(crate) pickle_loads: PyObject,
pub(crate) read_opt: ReadOptions,
// decrease db Rc last
-pub(crate) db: Arc<RefCell<DB>>,
+pub(crate) db: DbReferenceHolder,
pub(crate) raw_mode: bool,
}

@@ -133,7 +134,7 @@ impl Snapshot {

/// read from snapshot
fn __getitem__(&self, key: &PyAny, py: Python) -> PyResult<PyObject> {
-let db = self.db.borrow();
+let db = self.get_db().borrow();
let key = encode_key(key, self.raw_mode)?;
let value_result = if let Some(cf) = &self.column_family {
db.get_pinned_cf_opt(cf.deref(), &key[..], &self.read_opt)
@@ -152,33 +153,40 @@

impl Snapshot {
pub(crate) fn new(rdict: &Rdict, py: Python) -> PyResult<Self> {
-if let Some(db) = &rdict.db {
-let db_borrow = db.borrow();
-let snapshot = unsafe { librocksdb_sys::rocksdb_create_snapshot(db_borrow.inner()) };
-let r_opt: ReadOptions = rdict
-.read_opt_py
-.to_read_options(rdict.opt_py.raw_mode, py)?;
-unsafe {
-set_snapshot(r_opt.inner(), snapshot);
-}
-Ok(Snapshot {
-inner: snapshot,
-column_family: rdict.column_family.clone(),
-pickle_loads: rdict.loads.clone(),
-read_opt: r_opt,
-db: db.clone(),
-raw_mode: rdict.opt_py.raw_mode,
-})
-} else {
-Err(PyException::new_err("DB already closed"))
+let db_inner = rdict
+.db
+.get()
+.ok_or_else(|| DbClosedError::new_err("DB instance already closed"))?
+.borrow()
+.inner();
+let snapshot = unsafe { librocksdb_sys::rocksdb_create_snapshot(db_inner) };
+let r_opt: ReadOptions = rdict
+.read_opt_py
+.to_read_options(rdict.opt_py.raw_mode, py)?;
+unsafe {
+set_snapshot(r_opt.inner(), snapshot);
+}
+Ok(Snapshot {
+inner: snapshot,
+column_family: rdict.column_family.clone(),
+pickle_loads: rdict.loads.clone(),
+read_opt: r_opt,
+db: rdict.db.clone(),
+raw_mode: rdict.opt_py.raw_mode,
+})
}

+fn get_db(&self) -> &DbReference {
+self.db
+.get()
+.expect("Snapshot should never close its DbReference")
+}
}

impl Drop for Snapshot {
fn drop(&mut self) {
unsafe {
-librocksdb_sys::rocksdb_release_snapshot(self.db.borrow().inner(), self.inner);
+librocksdb_sys::rocksdb_release_snapshot(self.get_db().borrow().inner(), self.inner);
}
}
}
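
One detail worth noting in the Snapshot changes: get_db() can use expect() rather than returning an error because Snapshot::new clones the holder and the snapshot never calls close() on its own copy, so its reference stays populated for the snapshot's entire lifetime. A tiny sketch with stand-in types (String in place of rocksdb::DB, not the real code):

use std::sync::Arc;

struct Holder {
    inner: Option<Arc<String>>, // String stands in for rocksdb::DB
}

impl Holder {
    fn get(&self) -> Option<&Arc<String>> {
        self.inner.as_ref()
    }
    fn close(&mut self) {
        self.inner.take();
    }
}

fn main() {
    let mut rdict_side = Holder {
        inner: Some(Arc::new("db".to_string())),
    };
    // Snapshot::new conceptually keeps its own clone of the reference.
    let snapshot_side = Holder {
        inner: rdict_side.inner.clone(),
    };
    rdict_side.close(); // closing the Rdict's copy leaves the snapshot's intact
    assert!(snapshot_side.get().is_some()); // expect() here can never panic
}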
