diff --git a/src/rdict.rs b/src/rdict.rs index 086f799..3723655 100644 --- a/src/rdict.rs +++ b/src/rdict.rs @@ -708,24 +708,27 @@ impl Rdict { /// Args: /// wait (bool): whether to wait for the flush to finish. #[pyo3(signature = (wait = true))] - fn flush(&self, wait: bool) -> PyResult<()> { + fn flush(&self, wait: bool, py: Python) -> PyResult<()> { let db = self.get_db()?; - let mut f_opt = FlushOptions::new(); - f_opt.set_wait(wait); - if let Some(cf) = &self.column_family { - db.flush_cf_opt(cf, &f_opt) - } else { - db.flush_opt(&f_opt) - } + + py.allow_threads(|| { + let mut f_opt = FlushOptions::new(); + f_opt.set_wait(wait); + if let Some(cf) = &self.column_family { + db.flush_cf_opt(cf, &f_opt) + } else { + db.flush_opt(&f_opt) + } + }) .map_err(|e| PyException::new_err(e.into_string())) } /// Flushes the WAL buffer. If `sync` is set to `true`, also syncs /// the data to disk. #[pyo3(signature = (sync = true))] - fn flush_wal(&self, sync: bool) -> PyResult<()> { + fn flush_wal(&self, sync: bool, py: Python) -> PyResult<()> { let db = self.get_db()?; - db.flush_wal(sync) + py.allow_threads(|| db.flush_wal(sync)) .map_err(|e| PyException::new_err(e.into_string())) } @@ -977,25 +980,33 @@ impl Rdict { /// alive. `del` or `close` all associated instances mentioned /// above to actually shut down RocksDB. /// - fn close(&mut self) -> PyResult<()> { + fn close(&mut self, py: Python) -> PyResult<()> { // do not flush if readonly if let AccessTypeInner::ReadOnly { .. } | AccessTypeInner::Secondary { .. } = &self.access_type.0 { - drop(self.column_family.take()); - self.db.close(); + py.allow_threads(|| { + drop(self.column_family.take()); + self.db.close(); + }); return Ok(()); } - let f_opt = &self.flush_opt; - let db = self.get_db()?; - let flush_wal_result = db.flush_wal(true); - let flush_result = if let Some(cf) = &self.column_family { - db.flush_cf_opt(cf, &f_opt.into()) - } else { - db.flush_opt(&f_opt.into()) - }; - drop(self.column_family.take()); - self.db.close(); + + let (flush_wal_result, flush_result) = py.allow_threads(|| { + let f_opt = &self.flush_opt; + let db = self.get_db()?; + + let flush_wal_result = db.flush_wal(true); + let flush_result = if let Some(cf) = &self.column_family { + db.flush_cf_opt(cf, &f_opt.into()) + } else { + db.flush_opt(&f_opt.into()) + }; + drop(self.column_family.take()); + self.db.close(); + + Ok::<_, PyErr>((flush_wal_result, flush_result)) + })?; match (flush_result, flush_wal_result) { (Ok(_), Ok(_)) => Ok(()), (Err(e), Ok(_)) => Err(PyException::new_err(e.to_string())), @@ -1114,9 +1125,14 @@ impl Rdict { /// options (rocksdict.Options): Rocksdb options object #[staticmethod] #[pyo3(signature = (path, options = OptionsPy::new(false)))] - fn destroy(path: &str, options: OptionsPy) -> PyResult<()> { - fs::remove_file(config_file(path)).ok(); - DB::destroy(&options.inner_opt, path).map_err(|e| PyException::new_err(e.to_string())) + fn destroy(path: &str, options: OptionsPy, py: Python) -> PyResult<()> { + let inner_opt = options.inner_opt; + + py.allow_threads(|| { + fs::remove_file(config_file(path)).ok(); + DB::destroy(&inner_opt, path) + }) + .map_err(|e| PyException::new_err(e.to_string())) } /// Repair the database. @@ -1126,8 +1142,11 @@ impl Rdict { /// options (rocksdict.Options): Rocksdb options object #[staticmethod] #[pyo3(signature = (path, options = OptionsPy::new(false)))] - fn repair(path: &str, options: OptionsPy) -> PyResult<()> { - DB::repair(&options.inner_opt, path).map_err(|e| PyException::new_err(e.to_string())) + fn repair(path: &str, options: OptionsPy, py: Python) -> PyResult<()> { + let inner_opt = options.inner_opt; + + py.allow_threads(|| DB::repair(&inner_opt, path)) + .map_err(|e| PyException::new_err(e.to_string())) } #[staticmethod] @@ -1175,7 +1194,7 @@ fn get_batch_inner<'a>( for key in key_list { keys.push(encode_key(key, raw_mode)?); } - let values = db.batched_multi_get_cf_opt(cf, &keys, false, read_opt); + let values = py.allow_threads(|| db.batched_multi_get_cf_opt(cf, &keys, false, read_opt)); let result = PyList::empty(py); for v in values { match v { diff --git a/test/bench_rdict.py b/test/bench_rdict.py index 64d56e1..8daf36e 100644 --- a/test/bench_rdict.py +++ b/test/bench_rdict.py @@ -1,140 +1,183 @@ -from rocksdict import Rdict, Options +from rocksdict import Rdict, Options, WriteBatch, WriteOptions from random import randbytes from threading import Thread +from typing import List import time -def gen_rand_bytes(): - return [randbytes(128) for i in range(1024 * 1024)] +def gen_rand_bytes() -> List[bytes]: + return [randbytes(128) for _ in range(4 * 1024 * 1024)] -def perf_put_single_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() +def perf_put_single_thread(rand_bytes: List[bytes]): + rdict = Rdict("test.db", Options(raw_mode=True)) + + batch = WriteBatch(raw_mode=True) for k in rand_bytes: - rdict[k] = k - end = time.time() - print('Put performance: {} items in {} seconds'.format(len(rand_bytes), end - start)) + batch.put(k, k) + + start = time.perf_counter() + # Make the write sync so this doesn't just benchmark system cache state. + write_opt = WriteOptions() + write_opt.sync = True + rdict.write(batch, write_opt=write_opt) + end = time.perf_counter() + print( + "Put performance: {} items in {} seconds".format(len(rand_bytes), end - start) + ) count = 0 for k, v in rdict.items(): assert k == v count += 1 - assert count == len(rand_bytes) + assert count == len(rand_bytes), f"{count=} != {len(rand_bytes)}" rdict.close() - Rdict.destroy('test.db') + Rdict.destroy("test.db") + +def perf_put_multi_thread(rand_bytes: List[bytes], num_threads: int): + rdict = Rdict("test.db", Options(raw_mode=True)) + + def perf_put(batch: WriteBatch): + # Make the write sync so this doesn't just benchmark system cache state. + write_opt = WriteOptions() + write_opt.sync = True + rdict.write(batch, write_opt=write_opt) -def perf_put_multi_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() - THREAD = 4 - def perf_put(dat): - for k in dat: - rdict[k] = k threads = [] - each_len = len(rand_bytes) // THREAD - for i in range(THREAD): - t = Thread(target=perf_put, args=(rand_bytes[i*each_len:(i+1)*each_len],)) + each_len = len(rand_bytes) // num_threads + batches = [] + for i in range(num_threads): + batch = WriteBatch(raw_mode=True) + for val in rand_bytes[i * each_len : (i + 1) * each_len]: + batch.put(val, val) + batches.append(batch) + + start = time.perf_counter() + for batch in batches: + t = Thread(target=perf_put, args=(batch,)) t.start() threads.append(t) for t in threads: t.join() - end = time.time() - print('Put performance multi-thread: {} items in {} seconds'.format(len(rand_bytes), end - start)) + end = time.perf_counter() + print( + "Put performance multi-thread: {} items in {} seconds".format( + len(rand_bytes), end - start + ) + ) + count = 0 for k, v in rdict.items(): assert k == v count += 1 - assert count == len(rand_bytes) + assert count == len(rand_bytes), f"{count=} != {len(rand_bytes)}" rdict.close() - Rdict.destroy('test.db') + Rdict.destroy("test.db") -def perf_iterator_single_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() +def perf_iterator_single_thread(rand_bytes: List[bytes]): + rdict = Rdict("test.db", Options(raw_mode=True)) + start = time.perf_counter() count = 0 for k, v in rdict.items(): assert k == v count += 1 - end = time.time() + end = time.perf_counter() assert count == len(rand_bytes) - print('Iterator performance: {} items in {} seconds'.format(count, end - start)) + print("Iterator performance: {} items in {} seconds".format(count, end - start)) rdict.close() -def perf_iterator_multi_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() - THREAD = 4 +def perf_iterator_multi_thread(rand_bytes: List[bytes], num_threads: int): + rdict = Rdict("test.db", Options(raw_mode=True)) + start = time.perf_counter() + def perf_iter(): count = 0 for k, v in rdict.items(): assert k == v count += 1 assert count == len(rand_bytes) + threads = [] - for _ in range(THREAD): + for _ in range(num_threads): t = Thread(target=perf_iter) t.start() threads.append(t) for t in threads: t.join() - end = time.time() - print('Iterator performance multi-thread: {} items in {} seconds'.format(THREAD * len(rand_bytes), end - start)) + end = time.perf_counter() + print( + "Iterator performance multi-thread: {} items in {} seconds".format( + num_threads * len(rand_bytes), end - start + ) + ) rdict.close() -def perf_random_get_single_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() - for k in rand_bytes: - assert k == rdict[k] - end = time.time() - print('Get performance: {} items in {} seconds'.format(len(rand_bytes), end - start)) +def perf_random_get_single_thread(rand_bytes: List[bytes]): + rdict = Rdict("test.db", Options(raw_mode=True)) + start = time.perf_counter() + vals = rdict.get(rand_bytes) + for key, val in zip(rand_bytes, vals): + assert key == val + end = time.perf_counter() + print( + "Get performance: {} items in {} seconds".format(len(rand_bytes), end - start) + ) rdict.close() -def perf_random_get_multi_thread(rand_bytes): - rdict = Rdict('test.db', Options(raw_mode=True)) - start = time.time() - THREAD = 4 - def perf_get(dat): - for k in dat: - assert k == rdict[k] +def perf_random_get_multi_thread(rand_bytes: List[bytes], num_threads: int): + rdict = Rdict("test.db", Options(raw_mode=True)) + start = time.perf_counter() + + def perf_get(keys: List[bytes]): + vals = rdict.get(keys) + for key, val in zip(keys, vals): + assert key == val + threads = [] - each_len = len(rand_bytes) // THREAD - for i in range(THREAD): - t = Thread(target=perf_get, args=(rand_bytes[i*each_len:(i+1)*each_len],)) + each_len = len(rand_bytes) // num_threads + for i in range(num_threads): + t = Thread( + target=perf_get, args=(rand_bytes[i * each_len : (i + 1) * each_len],) + ) t.start() threads.append(t) for t in threads: t.join() - end = time.time() - print('Get performance multi-thread: {} items in {} seconds'.format(len(rand_bytes), end - start)) + end = time.perf_counter() + print( + "Get performance multi-thread: {} items in {} seconds".format( + len(rand_bytes), end - start + ) + ) rdict.close() -if __name__ == '__main__': - print('Gen rand bytes...') +if __name__ == "__main__": + print("Gen rand bytes...") rand_bytes = gen_rand_bytes() - print('Benchmarking Rdict Put...') + NUM_THREADS = 4 + + print("Benchmarking Rdict Put...") # perf write perf_put_single_thread(rand_bytes) - perf_put_multi_thread(rand_bytes) + perf_put_multi_thread(rand_bytes, num_threads=NUM_THREADS) # Create a new Rdict instance - rdict = Rdict('test.db', Options(raw_mode=True)) + rdict = Rdict("test.db", Options(raw_mode=True)) for b in rand_bytes: rdict[b] = b rdict.close() - print('Benchmarking Rdict Iterator...') + print("Benchmarking Rdict Iterator...") perf_iterator_single_thread(rand_bytes) - perf_iterator_multi_thread(rand_bytes) - print('Benchmarking Rdict Get...') + perf_iterator_multi_thread(rand_bytes, num_threads=NUM_THREADS) + print("Benchmarking Rdict Get...") perf_random_get_single_thread(rand_bytes) - perf_random_get_multi_thread(rand_bytes) + perf_random_get_multi_thread(rand_bytes, num_threads=NUM_THREADS) # Destroy the Rdict instance - Rdict.destroy('test.db') + Rdict.destroy("test.db")