Skip to content

Commit

Permalink
Updates from orm branch (#36)
Browse files Browse the repository at this point in the history
* improve sessions

* improve tests
  • Loading branch information
mkrd authored Nov 17, 2022
1 parent ae732a9 commit 755d7b6
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 136 deletions.
15 changes: 12 additions & 3 deletions dictdatabase/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import TypeVar, Any, Callable
from . import utils, io_safe, config
from . session import DDBSession
from . sessions import SessionFileFull, SessionFileKey, SessionFileWhere, SessionDirFull, SessionDirWhere

T = TypeVar("T")

Expand Down Expand Up @@ -200,7 +200,7 @@ def type_cast(value):
return type_cast(data)


def session(self, as_type: T = None) -> DDBSession[T]:
def session(self, as_type: T = None) -> SessionFileFull[T] | SessionFileKey[T] | SessionFileWhere[T] | SessionDirFull[T] | SessionDirWhere[T]:
"""
Opens a session to the selected file(s) or folder, depending on previous
`.at(...)` selection. Inside the with block, you have exclusive access
Expand All @@ -215,4 +215,13 @@ def session(self, as_type: T = None) -> DDBSession[T]:
- `FileNotFoundError`: If the file does not exist.
- `KeyError`: If a key is specified and it does not exist.
"""
return DDBSession(self.path, self.op_type, self.key, self.where, as_type)
if self.op_type.file_normal:
return SessionFileFull(self.path, as_type)
if self.op_type.file_key:
return SessionFileKey(self.path, self.key, as_type)
if self.op_type.file_where:
return SessionFileWhere(self.path, self.where, as_type)
if self.op_type.dir_normal:
return SessionDirFull(self.path, as_type)
if self.op_type.dir_where:
return SessionDirWhere(self.path, self.where, as_type)
130 changes: 0 additions & 130 deletions dictdatabase/session.py

This file was deleted.

211 changes: 211 additions & 0 deletions dictdatabase/sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from __future__ import annotations
from typing import Tuple, TypeVar, Generic, Any, Callable
from . import utils, io_unsafe, locking

from contextlib import contextmanager


T = TypeVar("T")
JSONSerializable = TypeVar("JSONSerializable", str, int, float, bool, None, list, dict)



def type_cast(obj, as_type):
return obj if as_type is None else as_type(obj)



class SessionBase:
in_session: bool
db_name: str
as_type: T

def __init__(self, db_name: str, as_type):
self.in_session = False
self.db_name = db_name
self.as_type = as_type

def __enter__(self):
self.in_session = True
self.data_handle = {}

def __exit__(self, type, value, tb):
write_lock = getattr(self, "write_lock", None)
if write_lock is not None:
if isinstance(write_lock, list):
for lock in write_lock:
lock._unlock()
else:
write_lock._unlock()
self.write_lock, self.in_session = None, False

def write(self):
if not self.in_session:
raise PermissionError("Only call write() inside a with statement.")



@contextmanager
def safe_context(super, self, *, db_names_to_lock=None):
"""
If an exception happens in the context, the __exit__ method of the passed super
class will be called.
"""
super.__enter__()
try:
if isinstance(db_names_to_lock, str):
self.write_lock = locking.WriteLock(self.db_name)
self.write_lock._lock()
elif isinstance(db_names_to_lock, list):
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
for lock in self.write_lock:
lock._lock()
yield
except BaseException as e:
super.__exit__(type(e), e, e.__traceback__)
raise e



########################################################################################
#### File sessions
########################################################################################



class SessionFileFull(SessionBase, Generic[T]):
"""
Context manager for read-write access to a full file.
Efficiency:
Reads and writes the entire file.
"""

def __enter__(self) -> Tuple[SessionFileFull, JSONSerializable | T]:
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.data_handle = io_unsafe.read(self.db_name)
return self, type_cast(self.data_handle, self.as_type)

def write(self):
super().write()
io_unsafe.write(self.db_name, self.data_handle)



class SessionFileKey(SessionBase, Generic[T]):
"""
Context manager for read-write access to a single key-value item in a file.
Efficiency:
Uses partial reading, which allows only reading the bytes of the key-value item.
When writing, only the bytes of the key-value and the bytes of the file after
the key-value are written.
"""

def __init__(self, db_name: str, key: str, as_type: T):
super().__init__(db_name, as_type)
self.key = key

def __enter__(self) -> Tuple[SessionFileKey, JSONSerializable | T]:
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.partial_handle = io_unsafe.get_partial_file_handle(self.db_name, self.key)
self.data_handle = self.partial_handle.partial_dict.value
return self, type_cast(self.data_handle, self.as_type)

def write(self):
super().write()
io_unsafe.partial_write(self.partial_handle)



class SessionFileWhere(SessionBase, Generic[T]):
"""
Context manager for read-write access to selection of key-value items in a file.
The where callable is called with the key and value of each item in the file.
Efficiency:
Reads and writes the entire file, so it is not more efficient than
SessionFileFull.
"""
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
super().__init__(db_name, as_type)
self.where = where

def __enter__(self) -> Tuple[SessionFileWhere, JSONSerializable | T]:
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.original_data = io_unsafe.read(self.db_name)
for k, v in self.original_data.items():
if self.where(k, v):
self.data_handle[k] = v
return self, type_cast(self.data_handle, self.as_type)

def write(self):
super().write()
self.original_data.update(self.data_handle)
io_unsafe.write(self.db_name, self.original_data)



########################################################################################
#### File sessions
########################################################################################



class SessionDirFull(SessionBase, Generic[T]):
"""
Context manager for read-write access to all files in a directory.
They are provided as a dict of {str(file_name): dict(file_content)}, where the
file name does not contain the directory name nor the file extension.
Efficiency:
Fully reads and writes all files.
"""
def __init__(self, db_name: str, as_type: T):
super().__init__(utils.find_all(db_name), as_type)

def __enter__(self) -> Tuple[SessionDirFull, JSONSerializable | T]:
with safe_context(super(), self, db_names_to_lock=self.db_name):
self.data_handle = {n.split("/")[-1]: io_unsafe.read(n) for n in self.db_name}
return self, type_cast(self.data_handle, self.as_type)

def write(self):
super().write()
for name in self.db_name:
io_unsafe.write(name, self.data_handle[name.split("/")[-1]])



class SessionDirWhere(SessionBase, Generic[T]):
"""
Context manager for read-write access to selection of files in a directory.
The where callable is called with the file name and parsed content of each file.
Efficiency:
Fully reads all files, but only writes the selected files.
"""
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
super().__init__(utils.find_all(db_name), as_type)
self.where = where

def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
with safe_context(super(), self):
selected_db_names, write_lock = [], []
for db_name in self.db_name:
lock = locking.WriteLock(db_name)
lock._lock()
k, v = db_name.split("/")[-1], io_unsafe.read(db_name)
if self.where(k, v):
self.data_handle[k] = v
write_lock.append(lock)
selected_db_names.append(db_name)
else:
lock._unlock()
self.write_lock = write_lock
self.db_name = selected_db_names
return self, type_cast(self.data_handle, self.as_type)

def write(self):
super().write()
for name in self.db_name:
io_unsafe.write(name, self.data_handle[name.split("/")[-1]])
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
from tests import TEST_DIR
import pytest
import shutil
import os


@pytest.fixture(scope="session")
def use_test_dir(request):
DDB.config.storage_directory = TEST_DIR
os.makedirs(TEST_DIR, exist_ok=True)
request.addfinalizer(lambda: shutil.rmtree(TEST_DIR))



@pytest.fixture(scope="function")
def name_of_test(request):
return request.function.__name__



@pytest.fixture(params=[True, False])
def use_compression(request):
DDB.config.use_compression = request.param
Expand Down
Loading

0 comments on commit 755d7b6

Please sign in to comment.