Merge branch 'main' into update-upstream
Congyuwang committed Dec 14, 2023
2 parents ac73f49 + 0501998 commit 02ae87c
Showing 2 changed files with 157 additions and 95 deletions.
77 changes: 48 additions & 29 deletions src/rdict.rs
@@ -708,24 +708,27 @@ impl Rdict {
     /// Args:
     ///     wait (bool): whether to wait for the flush to finish.
     #[pyo3(signature = (wait = true))]
-    fn flush(&self, wait: bool) -> PyResult<()> {
+    fn flush(&self, wait: bool, py: Python) -> PyResult<()> {
         let db = self.get_db()?;
-        let mut f_opt = FlushOptions::new();
-        f_opt.set_wait(wait);
-        if let Some(cf) = &self.column_family {
-            db.flush_cf_opt(cf, &f_opt)
-        } else {
-            db.flush_opt(&f_opt)
-        }
+
+        py.allow_threads(|| {
+            let mut f_opt = FlushOptions::new();
+            f_opt.set_wait(wait);
+            if let Some(cf) = &self.column_family {
+                db.flush_cf_opt(cf, &f_opt)
+            } else {
+                db.flush_opt(&f_opt)
+            }
+        })
         .map_err(|e| PyException::new_err(e.into_string()))
     }

     /// Flushes the WAL buffer. If `sync` is set to `true`, also syncs
     /// the data to disk.
     #[pyo3(signature = (sync = true))]
-    fn flush_wal(&self, sync: bool) -> PyResult<()> {
+    fn flush_wal(&self, sync: bool, py: Python) -> PyResult<()> {
         let db = self.get_db()?;
-        db.flush_wal(sync)
+        py.allow_threads(|| db.flush_wal(sync))
             .map_err(|e| PyException::new_err(e.into_string()))
     }

@@ -977,25 +980,33 @@ impl Rdict {
     /// alive. `del` or `close` all associated instances mentioned
     /// above to actually shut down RocksDB.
     ///
-    fn close(&mut self) -> PyResult<()> {
+    fn close(&mut self, py: Python) -> PyResult<()> {
         // do not flush if readonly
         if let AccessTypeInner::ReadOnly { .. } | AccessTypeInner::Secondary { .. } =
             &self.access_type.0
         {
-            drop(self.column_family.take());
-            self.db.close();
+            py.allow_threads(|| {
+                drop(self.column_family.take());
+                self.db.close();
+            });
             return Ok(());
         }
-        let f_opt = &self.flush_opt;
-        let db = self.get_db()?;
-        let flush_wal_result = db.flush_wal(true);
-        let flush_result = if let Some(cf) = &self.column_family {
-            db.flush_cf_opt(cf, &f_opt.into())
-        } else {
-            db.flush_opt(&f_opt.into())
-        };
-        drop(self.column_family.take());
-        self.db.close();
+
+        let (flush_wal_result, flush_result) = py.allow_threads(|| {
+            let f_opt = &self.flush_opt;
+            let db = self.get_db()?;
+
+            let flush_wal_result = db.flush_wal(true);
+            let flush_result = if let Some(cf) = &self.column_family {
+                db.flush_cf_opt(cf, &f_opt.into())
+            } else {
+                db.flush_opt(&f_opt.into())
+            };
+            drop(self.column_family.take());
+            self.db.close();
+
+            Ok::<_, PyErr>((flush_wal_result, flush_result))
+        })?;
         match (flush_result, flush_wal_result) {
             (Ok(_), Ok(_)) => Ok(()),
             (Err(e), Ok(_)) => Err(PyException::new_err(e.to_string())),
@@ -1114,9 +1125,14 @@ impl Rdict {
     ///     options (rocksdict.Options): Rocksdb options object
     #[staticmethod]
     #[pyo3(signature = (path, options = OptionsPy::new(false)))]
-    fn destroy(path: &str, options: OptionsPy) -> PyResult<()> {
-        fs::remove_file(config_file(path)).ok();
-        DB::destroy(&options.inner_opt, path).map_err(|e| PyException::new_err(e.to_string()))
+    fn destroy(path: &str, options: OptionsPy, py: Python) -> PyResult<()> {
+        let inner_opt = options.inner_opt;
+
+        py.allow_threads(|| {
+            fs::remove_file(config_file(path)).ok();
+            DB::destroy(&inner_opt, path)
+        })
+        .map_err(|e| PyException::new_err(e.to_string()))
     }

     /// Repair the database.
@@ -1126,8 +1142,11 @@
     ///     options (rocksdict.Options): Rocksdb options object
     #[staticmethod]
     #[pyo3(signature = (path, options = OptionsPy::new(false)))]
-    fn repair(path: &str, options: OptionsPy) -> PyResult<()> {
-        DB::repair(&options.inner_opt, path).map_err(|e| PyException::new_err(e.to_string()))
+    fn repair(path: &str, options: OptionsPy, py: Python) -> PyResult<()> {
+        let inner_opt = options.inner_opt;
+
+        py.allow_threads(|| DB::repair(&inner_opt, path))
+            .map_err(|e| PyException::new_err(e.to_string()))
     }

     #[staticmethod]
@@ -1175,7 +1194,7 @@ fn get_batch_inner<'a>(
     for key in key_list {
         keys.push(encode_key(key, raw_mode)?);
     }
-    let values = db.batched_multi_get_cf_opt(cf, &keys, false, read_opt);
+    let values = py.allow_threads(|| db.batched_multi_get_cf_opt(cf, &keys, false, read_opt));
     let result = PyList::empty(py);
     for v in values {
         match v {
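Every change in src/rdict.rs follows the same pattern: the blocking RocksDB call moves inside `py.allow_threads`, which releases the Python GIL while the closure runs, so flushing, closing, destroying, repairing, and batched reads no longer stall other Python threads. A minimal sketch of what this enables from the Python side, using only rocksdict calls that appear in this commit (the `demo.db` path and the helper names `slow_flush` and `busy_reader` are illustrative):

from threading import Thread
from rocksdict import Rdict, Options

db = Rdict("demo.db", Options(raw_mode=True))
db[b"key"] = b"value"

def slow_flush():
    # Blocks inside RocksDB, but after this commit it does so
    # without holding the GIL.
    db.flush(wait=True)

def busy_reader():
    # Can now make progress while the flush is in flight.
    for _ in range(100_000):
        assert db[b"key"] == b"value"

threads = [Thread(target=slow_flush), Thread(target=busy_reader)]
for t in threads:
    t.start()
for t in threads:
    t.join()

db.close()
Rdict.destroy("demo.db")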
175 changes: 109 additions & 66 deletions test/bench_rdict.py
@@ -1,140 +1,183 @@
-from rocksdict import Rdict, Options
+from rocksdict import Rdict, Options, WriteBatch, WriteOptions
 from random import randbytes
 from threading import Thread
+from typing import List
 import time


-def gen_rand_bytes():
-    return [randbytes(128) for i in range(1024 * 1024)]
+def gen_rand_bytes() -> List[bytes]:
+    return [randbytes(128) for _ in range(4 * 1024 * 1024)]


-def perf_put_single_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
+def perf_put_single_thread(rand_bytes: List[bytes]):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+
+    batch = WriteBatch(raw_mode=True)
     for k in rand_bytes:
-        rdict[k] = k
-    end = time.time()
-    print('Put performance: {} items in {} seconds'.format(len(rand_bytes), end - start))
+        batch.put(k, k)
+
+    start = time.perf_counter()
+    # Make the write sync so this doesn't just benchmark system cache state.
+    write_opt = WriteOptions()
+    write_opt.sync = True
+    rdict.write(batch, write_opt=write_opt)
+    end = time.perf_counter()
+    print(
+        "Put performance: {} items in {} seconds".format(len(rand_bytes), end - start)
+    )
     count = 0
     for k, v in rdict.items():
         assert k == v
         count += 1
-    assert count == len(rand_bytes)
+    assert count == len(rand_bytes), f"{count=} != {len(rand_bytes)}"
     rdict.close()
-    Rdict.destroy('test.db')
+    Rdict.destroy("test.db")


-def perf_put_multi_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
-    THREAD = 4
-    def perf_put(dat):
-        for k in dat:
-            rdict[k] = k
+def perf_put_multi_thread(rand_bytes: List[bytes], num_threads: int):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+
+    def perf_put(batch: WriteBatch):
+        # Make the write sync so this doesn't just benchmark system cache state.
+        write_opt = WriteOptions()
+        write_opt.sync = True
+        rdict.write(batch, write_opt=write_opt)
+
     threads = []
-    each_len = len(rand_bytes) // THREAD
-    for i in range(THREAD):
-        t = Thread(target=perf_put, args=(rand_bytes[i*each_len:(i+1)*each_len],))
+    each_len = len(rand_bytes) // num_threads
+    batches = []
+    for i in range(num_threads):
+        batch = WriteBatch(raw_mode=True)
+        for val in rand_bytes[i * each_len : (i + 1) * each_len]:
+            batch.put(val, val)
+        batches.append(batch)
+
+    start = time.perf_counter()
+    for batch in batches:
+        t = Thread(target=perf_put, args=(batch,))
         t.start()
         threads.append(t)
     for t in threads:
         t.join()
-    end = time.time()
-    print('Put performance multi-thread: {} items in {} seconds'.format(len(rand_bytes), end - start))
+    end = time.perf_counter()
+    print(
+        "Put performance multi-thread: {} items in {} seconds".format(
+            len(rand_bytes), end - start
+        )
+    )
+
     count = 0
     for k, v in rdict.items():
         assert k == v
         count += 1
-    assert count == len(rand_bytes)
+    assert count == len(rand_bytes), f"{count=} != {len(rand_bytes)}"
     rdict.close()
-    Rdict.destroy('test.db')
+    Rdict.destroy("test.db")


-def perf_iterator_single_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
+def perf_iterator_single_thread(rand_bytes: List[bytes]):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+    start = time.perf_counter()
     count = 0
     for k, v in rdict.items():
         assert k == v
         count += 1
-    end = time.time()
+    end = time.perf_counter()
     assert count == len(rand_bytes)
-    print('Iterator performance: {} items in {} seconds'.format(count, end - start))
+    print("Iterator performance: {} items in {} seconds".format(count, end - start))
     rdict.close()


-def perf_iterator_multi_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
-    THREAD = 4
+def perf_iterator_multi_thread(rand_bytes: List[bytes], num_threads: int):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+    start = time.perf_counter()
+
     def perf_iter():
         count = 0
         for k, v in rdict.items():
             assert k == v
             count += 1
         assert count == len(rand_bytes)

     threads = []
-    for _ in range(THREAD):
+    for _ in range(num_threads):
         t = Thread(target=perf_iter)
         t.start()
         threads.append(t)
     for t in threads:
         t.join()
-    end = time.time()
-    print('Iterator performance multi-thread: {} items in {} seconds'.format(THREAD * len(rand_bytes), end - start))
+    end = time.perf_counter()
+    print(
+        "Iterator performance multi-thread: {} items in {} seconds".format(
+            num_threads * len(rand_bytes), end - start
+        )
+    )
     rdict.close()


-def perf_random_get_single_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
-    for k in rand_bytes:
-        assert k == rdict[k]
-    end = time.time()
-    print('Get performance: {} items in {} seconds'.format(len(rand_bytes), end - start))
+def perf_random_get_single_thread(rand_bytes: List[bytes]):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+    start = time.perf_counter()
+    vals = rdict.get(rand_bytes)
+    for key, val in zip(rand_bytes, vals):
+        assert key == val
+    end = time.perf_counter()
+    print(
+        "Get performance: {} items in {} seconds".format(len(rand_bytes), end - start)
+    )
     rdict.close()


-def perf_random_get_multi_thread(rand_bytes):
-    rdict = Rdict('test.db', Options(raw_mode=True))
-    start = time.time()
-    THREAD = 4
-    def perf_get(dat):
-        for k in dat:
-            assert k == rdict[k]
+def perf_random_get_multi_thread(rand_bytes: List[bytes], num_threads: int):
+    rdict = Rdict("test.db", Options(raw_mode=True))
+    start = time.perf_counter()
+
+    def perf_get(keys: List[bytes]):
+        vals = rdict.get(keys)
+        for key, val in zip(keys, vals):
+            assert key == val
+
     threads = []
-    each_len = len(rand_bytes) // THREAD
-    for i in range(THREAD):
-        t = Thread(target=perf_get, args=(rand_bytes[i*each_len:(i+1)*each_len],))
+    each_len = len(rand_bytes) // num_threads
+    for i in range(num_threads):
+        t = Thread(
+            target=perf_get, args=(rand_bytes[i * each_len : (i + 1) * each_len],)
+        )
         t.start()
         threads.append(t)
     for t in threads:
         t.join()
-    end = time.time()
-    print('Get performance multi-thread: {} items in {} seconds'.format(len(rand_bytes), end - start))
+    end = time.perf_counter()
+    print(
+        "Get performance multi-thread: {} items in {} seconds".format(
+            len(rand_bytes), end - start
+        )
+    )
     rdict.close()


-if __name__ == '__main__':
-    print('Gen rand bytes...')
+if __name__ == "__main__":
+    print("Gen rand bytes...")
     rand_bytes = gen_rand_bytes()

-    print('Benchmarking Rdict Put...')
+    NUM_THREADS = 4
+
+    print("Benchmarking Rdict Put...")
     # perf write
     perf_put_single_thread(rand_bytes)
-    perf_put_multi_thread(rand_bytes)
+    perf_put_multi_thread(rand_bytes, num_threads=NUM_THREADS)

     # Create a new Rdict instance
-    rdict = Rdict('test.db', Options(raw_mode=True))
+    rdict = Rdict("test.db", Options(raw_mode=True))
     for b in rand_bytes:
         rdict[b] = b
     rdict.close()
-    print('Benchmarking Rdict Iterator...')
+    print("Benchmarking Rdict Iterator...")
     perf_iterator_single_thread(rand_bytes)
-    perf_iterator_multi_thread(rand_bytes)
-    print('Benchmarking Rdict Get...')
+    perf_iterator_multi_thread(rand_bytes, num_threads=NUM_THREADS)
+    print("Benchmarking Rdict Get...")
     perf_random_get_single_thread(rand_bytes)
-    perf_random_get_multi_thread(rand_bytes)
+    perf_random_get_multi_thread(rand_bytes, num_threads=NUM_THREADS)

     # Destroy the Rdict instance
-    Rdict.destroy('test.db')
+    Rdict.destroy("test.db")
