diff --git a/DictDataBase.code-workspace b/DictDataBase.code-workspace index 3f9af16..2b95288 100644 --- a/DictDataBase.code-workspace +++ b/DictDataBase.code-workspace @@ -5,6 +5,12 @@ } ], "settings": { - "python.pythonPath": ".venv/bin/python3" + "[python]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "editor.codeActionsOnSave": { + "source.organizeImports": true + }, } } diff --git a/dictdatabase/configuration.py b/dictdatabase/configuration.py index 6a69d41..433ba0e 100644 --- a/dictdatabase/configuration.py +++ b/dictdatabase/configuration.py @@ -2,7 +2,6 @@ class Confuguration: - __slots__ = ("storage_directory", "indent", "use_compression", "use_orjson") storage_directory: str diff --git a/dictdatabase/indexing.py b/dictdatabase/indexing.py index df719d5..39dd933 100644 --- a/dictdatabase/indexing.py +++ b/dictdatabase/indexing.py @@ -26,19 +26,19 @@ class Indexer: """ - The Indexer takes the name of a database file, and tries to load the .index file - of the corresponding database file. - - The name of the index file is the name of the database file, with the extension - .index and all "/" replaced with "___" - - The content of the index file is a json object, where the keys are keys inside - the database json file, and the values are lists of 5 elements: - - start_index: The index of the first byte of the value of the key in the database file - - end_index: The index of the last byte of the value of the key in the database file - - indent_level: The indent level of the key in the database file - - indent_with: The indent string used. - - value_hash: The hash of the value bytes + The Indexer takes the name of a database file, and tries to load the .index file + of the corresponding database file. + + The name of the index file is the name of the database file, with the extension + .index and all "/" replaced with "___" + + The content of the index file is a json object, where the keys are keys inside + the database json file, and the values are lists of 5 elements: + - start_index: The index of the first byte of the value of the key in the database file + - end_index: The index of the last byte of the value of the key in the database file + - indent_level: The indent level of the key in the database file + - indent_with: The indent string used. + - value_hash: The hash of the value bytes """ __slots__ = ("data", "path") @@ -59,15 +59,13 @@ def __init__(self, db_name: str) -> None: except orjson.JSONDecodeError: self.data = {} - def get(self, key: str) -> Union[list, None]: """ - Returns a list of 5 elements for a key if it exists, otherwise None - Elements:[start_index, end_index, indent_level, indent_with, value_hash] + Returns a list of 5 elements for a key if it exists, otherwise None + Elements:[start_index, end_index, indent_level, indent_with, value_hash] """ return self.data.get(key, None) - def write( self, key: str, @@ -79,7 +77,7 @@ def write( old_value_end: int, ) -> None: """ - Write index information for a key to the index file + Write index information for a key to the index file """ if self.data.get(key, None) is not None: diff --git a/dictdatabase/io_bytes.py b/dictdatabase/io_bytes.py index 2ccf472..276a0c3 100644 --- a/dictdatabase/io_bytes.py +++ b/dictdatabase/io_bytes.py @@ -6,37 +6,34 @@ def read(db_name: str, *, start: int = None, end: int = None) -> bytes: """ - Read the content of a file as bytes. 
Reading works even when the config - changes, so a compressed ddb file can also be read if compression is - disabled, and vice versa. - - If no compression is used, efficient reading can be done by specifying a start - and end byte index, such that only the bytes in that range are read from the - file. - - If compression is used, specifying a start and end byte index is still possible, - but the entire file has to be read and decompressed first, and then the bytes - in the range are returned. This is because the compressed file is not seekable. - - Args: - - `db_name`: The name of the database file to read from. - - `start`: The start byte index to read from. - - `end`: The end byte index to read up to (not included). - - Raises: - - `FileNotFoundError`: If the file does not exist as .json nor .ddb. - - `OSError`: If no compression is used and `start` is negative. - - `FileExistsError`: If the file exists as .json and .ddb. + Read the content of a file as bytes. Reading works even when the config + changes, so a compressed ddb file can also be read if compression is + disabled, and vice versa. + + If no compression is used, efficient reading can be done by specifying a start + and end byte index, such that only the bytes in that range are read from the + file. + + If compression is used, specifying a start and end byte index is still possible, + but the entire file has to be read and decompressed first, and then the bytes + in the range are returned. This is because the compressed file is not seekable. + + Args: + - `db_name`: The name of the database file to read from. + - `start`: The start byte index to read from. + - `end`: The end byte index to read up to (not included). + + Raises: + - `FileNotFoundError`: If the file does not exist as .json nor .ddb. + - `OSError`: If no compression is used and `start` is negative. + - `FileExistsError`: If the file exists as .json and .ddb. """ json_path, json_exists, ddb_path, ddb_exists = utils.file_info(db_name) if json_exists: if ddb_exists: - raise FileExistsError( - f"Inconsistent: \"{db_name}\" exists as .json and .ddb." - "Please remove one of them." - ) + raise FileExistsError(f'Inconsistent: "{db_name}" exists as .json and .ddb.' "Please remove one of them.") with open(json_path, "rb") as f: if start is None and end is None: return f.read() @@ -46,7 +43,7 @@ def read(db_name: str, *, start: int = None, end: int = None) -> bytes: return f.read() return f.read(end - start) if not ddb_exists: - raise FileNotFoundError(f"No database file exists for \"{db_name}\"") + raise FileNotFoundError(f'No database file exists for "{db_name}"') with open(ddb_path, "rb") as f: json_bytes = zlib.decompress(f.read()) if start is None and end is None: @@ -56,19 +53,17 @@ def read(db_name: str, *, start: int = None, end: int = None) -> bytes: return json_bytes[start:end] - - def write(db_name: str, dump: bytes, *, start: int = None) -> None: """ - Write the bytes to the file of the db_path. If the db was compressed but no - compression is enabled, remove the compressed file, and vice versa. - - Args: - - `db_name`: The name of the database to write to. - - `dump`: The bytes to write to the file, representing correct JSON when - decoded. - - `start`: The start byte index to write to. If None, the whole file is overwritten. - If the original content was longer, the rest truncated. + Write the bytes to the file of the db_path. If the db was compressed but no + compression is enabled, remove the compressed file, and vice versa. 
+ + Args: + - `db_name`: The name of the database to write to. + - `dump`: The bytes to write to the file, representing correct JSON when + decoded. + - `start`: The start byte index to write to. If None, the whole file is overwritten. + If the original content was longer, the rest truncated. """ json_path, json_exists, ddb_path, ddb_exists = utils.file_info(db_name) diff --git a/dictdatabase/io_safe.py b/dictdatabase/io_safe.py index dae9699..d6586a0 100644 --- a/dictdatabase/io_safe.py +++ b/dictdatabase/io_safe.py @@ -5,10 +5,10 @@ def read(file_name: str) -> dict: """ - Read the content of a file as a dict. + Read the content of a file as a dict. - Args: - - `file_name`: The name of the file to read from. + Args: + - `file_name`: The name of the file to read from. """ _, json_exists, _, ddb_exists = utils.file_info(file_name) @@ -20,14 +20,13 @@ def read(file_name: str) -> dict: return io_unsafe.read(file_name) - def partial_read(file_name: str, key: str) -> dict: """ - Read only the value of a key-value pair from a file. + Read only the value of a key-value pair from a file. - Args: - - `file_name`: The name of the file to read from. - - `key`: The key to read the value of. + Args: + - `file_name`: The name of the file to read from. + - `key`: The key to read the value of. """ _, json_exists, _, ddb_exists = utils.file_info(file_name) @@ -39,14 +38,13 @@ def partial_read(file_name: str, key: str) -> dict: return io_unsafe.partial_read(file_name, key) - def write(file_name: str, data: dict) -> None: """ - Ensures that writing only starts if there is no reading or writing in progress. + Ensures that writing only starts if there is no reading or writing in progress. - Args: - - `file_name`: The name of the file to write to. - - `data`: The data to write to the file. + Args: + - `file_name`: The name of the file to write to. + - `data`: The data to write to the file. """ dirname = os.path.dirname(f"{config.storage_directory}/{file_name}.any") @@ -56,13 +54,12 @@ def write(file_name: str, data: dict) -> None: io_unsafe.write(file_name, data) - def delete(file_name: str) -> None: """ - Ensures that deleting only starts if there is no reading or writing in progress. + Ensures that deleting only starts if there is no reading or writing in progress. - Args: - - `file_name`: The name of the file to delete. + Args: + - `file_name`: The name of the file to delete. """ json_path, json_exists, ddb_path, ddb_exists = utils.file_info(file_name) diff --git a/dictdatabase/io_unsafe.py b/dictdatabase/io_unsafe.py index 3780d0f..eb3f471 100644 --- a/dictdatabase/io_unsafe.py +++ b/dictdatabase/io_unsafe.py @@ -35,9 +35,9 @@ class PartialFileHandle: def read(db_name: str) -> dict: """ - Read the file at db_path from the configured storage directory. - Make sure the file exists. If it does not a FileNotFoundError is - raised. + Read the file at db_path from the configured storage directory. + Make sure the file exists. If it does not a FileNotFoundError is + raised. """ # Always use orjson to read the file, because it is faster return orjson.loads(io_bytes.read(db_name)) @@ -50,9 +50,9 @@ def read(db_name: str) -> dict: def try_read_bytes_using_indexer(indexer: indexing.Indexer, db_name: str, key: str) -> bytes | None: """ - Check if the key info is saved in the file's index file. - If it is and the value has not changed, return the value bytes. - Otherwise return None. + Check if the key info is saved in the file's index file. + If it is and the value has not changed, return the value bytes. 
+ Otherwise return None. """ if (index := indexer.get(key)) is None: @@ -66,12 +66,12 @@ def try_read_bytes_using_indexer(indexer: indexing.Indexer, db_name: str, key: s def partial_read(db_name: str, key: str) -> dict | None: """ - Partially read a key from a db. - The key MUST be unique in the entire db, otherwise the behavior is undefined. - This is a lot faster than reading the entire db, because it does not parse - the entire file, but only the part part of the : pair. + Partially read a key from a db. + The key MUST be unique in the entire db, otherwise the behavior is undefined. + This is a lot faster than reading the entire db, because it does not parse + the entire file, but only the part part of the : pair. - If the key is not found, a `KeyError` is raised. + If the key is not found, a `KeyError` is raised. """ # Search for key in the index file @@ -90,7 +90,7 @@ def partial_read(db_name: str, key: str) -> dict | None: start = key_end + (1 if all_file_bytes[key_end] == byte_codes.SPACE else 0) end = utils.seek_index_through_value_bytes(all_file_bytes, start) - indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, key_start) + indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, key_start) value_bytes = all_file_bytes[start:end] value_hash = hashlib.sha256(value_bytes).hexdigest() @@ -106,9 +106,9 @@ def partial_read(db_name: str, key: str) -> dict | None: def serialize_data_to_json_bytes(data: dict) -> bytes: """ - Serialize the data as json bytes. Depending on the config, - this can be done with orjson or the standard json module. - Additionally config.indent is respected. + Serialize the data as json bytes. Depending on the config, + this can be done with orjson or the standard json module. + Additionally config.indent is respected. """ if config.use_orjson: option = (orjson.OPT_INDENT_2 if config.indent else 0) | orjson.OPT_SORT_KEYS @@ -120,8 +120,8 @@ def serialize_data_to_json_bytes(data: dict) -> bytes: def write(db_name: str, data: dict) -> None: """ - Write the dict db dumped as a json string - to the file of the db_path. + Write the dict db dumped as a json string + to the file of the db_path. """ data_bytes = serialize_data_to_json_bytes(data) io_bytes.write(db_name, data_bytes) @@ -138,12 +138,12 @@ def try_get_partial_file_handle_by_index( key: str, ) -> tuple[PartialFileHandle | None, bytes | None]: """ - Try to get a partial file handle by using the key entry in the index file. + Try to get a partial file handle by using the key entry in the index file. - If the data could be read from the index file, a tuple of the partial file - handle and None is returned. - If the data could not be read from the index file, a tuple of None and the file - bytes is returned, so that the file bytes can be searched for the key. + If the data could be read from the index file, a tuple of the partial file + handle and None is returned. + If the data could not be read from the index file, a tuple of None and the file + bytes is returned, so that the file bytes can be searched for the key. """ if (index := indexer.get(key)) is None: @@ -176,12 +176,12 @@ def try_get_partial_file_handle_by_index( def get_partial_file_handle(db_name: str, key: str) -> PartialFileHandle: """ - Partially read a key from a db. - The key MUST be unique in the entire db, otherwise the behavior is undefined. - This is a lot faster than reading the entire db, because it does not parse - the entire file, but only the part part of the : pair. 
+ Partially read a key from a db. + The key MUST be unique in the entire db, otherwise the behavior is undefined. + This is a lot faster than reading the entire db, because it does not parse + the entire file, but only the part part of the : pair. - If the key is not found, a `KeyError` is raised. + If the key is not found, a `KeyError` is raised. """ # Search for key in the index file @@ -194,13 +194,13 @@ def get_partial_file_handle(db_name: str, key: str) -> PartialFileHandle: key_start, key_end = utils.find_outermost_key_in_json_bytes(all_file_bytes, key) if key_end == -1: - raise KeyError(f"Key \"{key}\" not found in db \"{db_name}\"") + raise KeyError(f'Key "{key}" not found in db "{db_name}"') # Key found, now determine the bounding byte indices of the value start = key_end + (1 if all_file_bytes[key_end] == byte_codes.SPACE else 0) end = utils.seek_index_through_value_bytes(all_file_bytes, start) - indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, key_start) + indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, key_start) partial_value = orjson.loads(all_file_bytes[start:end]) prefix_bytes = all_file_bytes[:start] if config.use_compression else None @@ -210,7 +210,7 @@ def get_partial_file_handle(db_name: str, key: str) -> PartialFileHandle: def partial_write(pf: PartialFileHandle) -> None: """ - Write a partial file handle to the db. + Write a partial file handle to the db. """ partial_bytes = serialize_data_to_json_bytes(pf.partial_dict.value) diff --git a/dictdatabase/locking.py b/dictdatabase/locking.py index c48a42c..069088a 100644 --- a/dictdatabase/locking.py +++ b/dictdatabase/locking.py @@ -178,6 +178,7 @@ class ReadLock(AbstractLock): A file-based read lock. Multiple threads/processes can simultaneously hold a read lock unless there's a write lock. """ + mode = "read" def _lock(self) -> None: @@ -196,8 +197,7 @@ def _lock(self) -> None: # Try to acquire lock until conditions are met or a timeout occurs while True: if not self.snapshot.any_write_locks or ( - not self.snapshot.any_has_write_locks - and self.snapshot.oldest_need(self.need_lock) + not self.snapshot.any_has_write_locks and self.snapshot.oldest_need(self.need_lock) ): self.has_lock = self.has_lock.new_with_updated_time() os_touch(self.has_lock.path) @@ -214,6 +214,7 @@ class WriteLock(AbstractLock): A file-based write lock. Only one thread/process can hold a write lock, blocking others from acquiring either read or write locks. """ + mode = "write" def _lock(self) -> None: diff --git a/dictdatabase/models.py b/dictdatabase/models.py index 405e3bc..fdc2822 100644 --- a/dictdatabase/models.py +++ b/dictdatabase/models.py @@ -61,7 +61,6 @@ def dir_where(self) -> bool: return self.dir and self.where and not self.key - def at(*path, key: str = None, where: Callable[[Any, Any], bool] = None) -> DDBMethodChooser: """ Select a file or folder to perform an operation on. 
@@ -89,7 +88,6 @@ def at(*path, key: str = None, where: Callable[[Any, Any], bool] = None) -> DDBM class DDBMethodChooser: - __slots__ = ("path", "key", "where", "op_type") path: str @@ -97,8 +95,8 @@ class DDBMethodChooser: where: Callable[[Any, Any], bool] op_type: OperationType - - def __init__(self, + def __init__( + self, path: tuple, key: str = None, where: Callable[[Any, Any], bool] = None, @@ -106,7 +104,7 @@ def __init__(self, # Convert path to a list of strings pc = [] for p in path: - pc += p if isinstance(p, list) else [p] + pc += p if isinstance(p, list) else [p] self.path = "/".join([str(p) for p in pc]) self.key = key self.where = where @@ -115,7 +113,6 @@ def __init__(self, # - Both key and where cannot be not None at the same time # - If key is not None, then there is no wildcard in the path. - def exists(self) -> bool: """ Efficiently checks if a database exists. If the selected path contains @@ -136,7 +133,6 @@ def exists(self) -> bool: # Key is passed and occurs is True return io_safe.partial_read(self.path, key=self.key) is not None - def create(self, data: dict | None = None, force_overwrite: bool = False) -> None: """ Create a new file with the given data as the content. If the file @@ -154,13 +150,14 @@ def create(self, data: dict | None = None, force_overwrite: bool = False) -> Non # Except if db exists and force_overwrite is False if not force_overwrite and self.exists(): - raise FileExistsError(f"Database {self.path} already exists in {config.storage_directory}. Pass force_overwrite=True to overwrite.") + raise FileExistsError( + f"Database {self.path} already exists in {config.storage_directory}. Pass force_overwrite=True to overwrite." + ) # Write db to file if data is None: data = {} io_safe.write(self.path, data) - def delete(self) -> None: """ Delete the file at the selected path. @@ -169,7 +166,6 @@ def delete(self) -> None: raise RuntimeError("DDB.at().delete() cannot be used with the where or key parameters") io_safe.delete(self.path) - def read(self, as_type: Type[T] = None) -> dict | T | None: """ Reads a file or folder depending on previous `.at(...)` selection. @@ -212,8 +208,9 @@ def type_cast(value): return type_cast(data) - - def session(self, as_type: Type[T] = None) -> SessionFileFull[T] | SessionFileKey[T] | SessionFileWhere[T] | SessionDirFull[T] | SessionDirWhere[T]: + def session( + self, as_type: Type[T] = None + ) -> SessionFileFull[T] | SessionFileKey[T] | SessionFileWhere[T] | SessionDirFull[T] | SessionDirWhere[T]: """ Opens a session to the selected file(s) or folder, depending on previous `.at(...)` selection. Inside the with block, you have exclusive access diff --git a/dictdatabase/sessions.py b/dictdatabase/sessions.py index 9be9bbe..36e1a38 100644 --- a/dictdatabase/sessions.py +++ b/dictdatabase/sessions.py @@ -9,12 +9,10 @@ JSONSerializable = TypeVar("JSONSerializable", str, int, float, bool, None, list, dict) - def type_cast(obj, as_type): return obj if as_type is None else as_type(obj) - class SessionBase: in_session: bool db_name: str @@ -44,12 +42,11 @@ def write(self): raise PermissionError("Only call write() inside a with statement.") - @contextmanager def safe_context(super, self, *, db_names_to_lock=None): """ - If an exception happens in the context, the __exit__ method of the passed super - class will be called. + If an exception happens in the context, the __exit__ method of the passed super + class will be called. """ super.__enter__() try: @@ -66,19 +63,17 @@ class will be called. 
raise e - ######################################################################################## #### File sessions ######################################################################################## - class SessionFileFull(SessionBase, Generic[T]): """ - Context manager for read-write access to a full file. + Context manager for read-write access to a full file. - Efficiency: - Reads and writes the entire file. + Efficiency: + Reads and writes the entire file. """ def __enter__(self) -> Tuple[SessionFileFull, JSONSerializable | T]: @@ -91,15 +86,14 @@ def write(self): io_unsafe.write(self.db_name, self.data_handle) - class SessionFileKey(SessionBase, Generic[T]): """ - Context manager for read-write access to a single key-value item in a file. + Context manager for read-write access to a single key-value item in a file. - Efficiency: - Uses partial reading, which allows only reading the bytes of the key-value item. - When writing, only the bytes of the key-value and the bytes of the file after - the key-value are written. + Efficiency: + Uses partial reading, which allows only reading the bytes of the key-value item. + When writing, only the bytes of the key-value and the bytes of the file after + the key-value are written. """ def __init__(self, db_name: str, key: str, as_type: T): @@ -117,16 +111,16 @@ def write(self): io_unsafe.partial_write(self.partial_handle) - class SessionFileWhere(SessionBase, Generic[T]): """ - Context manager for read-write access to selection of key-value items in a file. - The where callable is called with the key and value of each item in the file. + Context manager for read-write access to selection of key-value items in a file. + The where callable is called with the key and value of each item in the file. - Efficiency: - Reads and writes the entire file, so it is not more efficient than - SessionFileFull. + Efficiency: + Reads and writes the entire file, so it is not more efficient than + SessionFileFull. """ + def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T): super().__init__(db_name, as_type) self.where = where @@ -145,22 +139,21 @@ def write(self): io_unsafe.write(self.db_name, self.original_data) - ######################################################################################## #### File sessions ######################################################################################## - class SessionDirFull(SessionBase, Generic[T]): """ - Context manager for read-write access to all files in a directory. - They are provided as a dict of {str(file_name): dict(file_content)}, where the - file name does not contain the directory name nor the file extension. + Context manager for read-write access to all files in a directory. + They are provided as a dict of {str(file_name): dict(file_content)}, where the + file name does not contain the directory name nor the file extension. - Efficiency: - Fully reads and writes all files. + Efficiency: + Fully reads and writes all files. """ + def __init__(self, db_name: str, as_type: T): super().__init__(utils.find_all(db_name), as_type) @@ -175,15 +168,15 @@ def write(self): io_unsafe.write(name, self.data_handle[name.split("/")[-1]]) - class SessionDirWhere(SessionBase, Generic[T]): """ - Context manager for read-write access to selection of files in a directory. - The where callable is called with the file name and parsed content of each file. + Context manager for read-write access to selection of files in a directory. 
+ The where callable is called with the file name and parsed content of each file. - Efficiency: - Fully reads all files, but only writes the selected files. + Efficiency: + Fully reads all files, but only writes the selected files. """ + def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T): super().__init__(utils.find_all(db_name), as_type) self.where = where diff --git a/dictdatabase/utils.py b/dictdatabase/utils.py index db150be..49418ae 100644 --- a/dictdatabase/utils.py +++ b/dictdatabase/utils.py @@ -9,14 +9,14 @@ def file_info(db_name: str) -> Tuple[str, bool, str, bool]: """ - Returns a tuple of four elements, the first and third being the paths to the - JSON and DDB files, and the second and third being booleans indicating whether - those files exist: + Returns a tuple of four elements, the first and third being the paths to the + JSON and DDB files, and the second and third being booleans indicating whether + those files exist: - >>> (json_path, json_exists, ddb_path, ddb_exists) + >>> (json_path, json_exists, ddb_path, ddb_exists) - Args: - - `db_name`: The name of the database + Args: + - `db_name`: The name of the database """ base = f"{config.storage_directory}/{db_name}" j, d = f"{base}.json", f"{base}.ddb" @@ -171,7 +171,7 @@ def find_outermost_key_in_json_bytes(json_bytes: bytes, key: str) -> Tuple[int, # TODO: Very strict. the key must have a colon directly after it # For example {"a": 1} will work, but {"a" : 1} will not work! - key = f"\"{key}\":".encode() + key = f'"{key}":'.encode() if (curr_i := json_bytes.find(key, 0)) == -1: return (-1, -1) diff --git a/profiler.py b/profiler.py index aa9943e..8e788ef 100644 --- a/profiler.py +++ b/profiler.py @@ -1,10 +1,10 @@ from distutils.command.config import config -import dictdatabase as DDB -from dictdatabase import io_unsafe + from path_dict import PathDict from pyinstrument import profiler - +import dictdatabase as DDB +from dictdatabase import io_unsafe DDB.config.storage_directory = "./test_db/production_database" DDB.config.use_orjson = True @@ -13,14 +13,14 @@ p = profiler.Profiler(interval=0.0001) with p: - # fM44 is small - # a2lU has many annotations - # DDB.at("tasks", key="fM44").read(key="fM44", as_type=PathDict) - for _ in range(10): - with DDB.at("tasks", key="a2lU").session(as_type=PathDict) as (session, task): - task["jay"] = lambda x: (x or 0) + 1 - session.write() - # DDB.at("tasks_as_dir/*").read() + # fM44 is small + # a2lU has many annotations + # DDB.at("tasks", key="fM44").read(key="fM44", as_type=PathDict) + for _ in range(10): + with DDB.at("tasks", key="a2lU").session(as_type=PathDict) as (session, task): + task["jay"] = lambda x: (x or 0) + 1 + session.write() + # DDB.at("tasks_as_dir/*").read() p.open_in_browser(timeline=False) diff --git a/pyproject.toml b/pyproject.toml index 1483bd7..3fb9d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] show-fixes = true +line-length = 120 select = [ "ANN", # annotations "B", # bugbear @@ -54,3 +55,7 @@ ignore = [ "W191", # indentation contains tabs "E741", # ambiguous variable name ] + +[tool.ruff.format] +indent-style = "tab" +quote-style = "double" diff --git a/scenario_comparison.py b/scenario_comparison.py index 84b0b30..66ade1b 100644 --- a/scenario_comparison.py +++ b/scenario_comparison.py @@ -1,9 +1,10 @@ - -import dictdatabase as DDB -from pyinstrument import profiler -from pathlib import Path import random import time +from pathlib import Path + 
+from pyinstrument import profiler + +import dictdatabase as DDB DDB.config.storage_directory = ".ddb_scenario_comparison" Path(DDB.config.storage_directory).mkdir(exist_ok=True) @@ -12,16 +13,16 @@ # Create a database with 10_000 entries all_users = {} for i in range(10_000): - print(i) - user = { - "id": "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)), - "name": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)), - "surname": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=20)), - "description": "".join(random.choices("abcdefghij\"klmnopqrst😁uvwxyz\\ ", k=5000)), - "age": random.randint(0, 100), - } - all_users[user["id"]] = user - DDB.at("users_dir", user["id"]).create(user) + print(i) + user = { + "id": "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)), + "name": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)), + "surname": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=20)), + "description": "".join(random.choices('abcdefghij"klmnopqrst😁uvwxyz\\ ', k=5000)), + "age": random.randint(0, 100), + } + all_users[user["id"]] = user + DDB.at("users_dir", user["id"]).create(user) DDB.at("users").create(all_users) @@ -32,7 +33,7 @@ # 06.11.22: 2695ms t1 = time.monotonic() with profiler.Profiler() as p: - DDB.at("users_dir/*").read() + DDB.at("users_dir/*").read() p.open_in_browser() print("Read all users from directory:", time.monotonic() - t1) diff --git a/scene_random_writes.py b/scene_random_writes.py index 1f44f43..08763a4 100644 --- a/scene_random_writes.py +++ b/scene_random_writes.py @@ -1,8 +1,8 @@ - -import dictdatabase as DDB import random + from pyinstrument.profiler import Profiler +import dictdatabase as DDB user_count = 100_000 @@ -25,10 +25,10 @@ p = Profiler(interval=0.0001) p.start() for it in range(500): - print(it) - user_id = str(random.randint(user_count - 100, user_count - 1)) - with DDB.at("users", key=user_id).session() as (session, user): - user["age"] += 1 - session.write() + print(it) + user_id = str(random.randint(user_count - 100, user_count - 1)) + with DDB.at("users", key=user_id).session() as (session, user): + user["age"] += 1 + session.write() p.stop() p.open_in_browser(timeline=False) diff --git a/test_key_finder.py b/test_key_finder.py index a3c7309..115845a 100644 --- a/test_key_finder.py +++ b/test_key_finder.py @@ -3,17 +3,16 @@ from dictdatabase import utils test_dict = { - "b": 2, - "c": { - "a": 1, - "b": 2, - }, - "d": { - "a": 1, - "b": 2, - }, - "a": 1, - + "b": 2, + "c": { + "a": 1, + "b": 2, + }, + "d": { + "a": 1, + "b": 2, + }, + "a": 1, } json_str = json.dumps(test_dict, indent=2, sort_keys=False) @@ -23,8 +22,7 @@ print("lel") print(index) -print(json_bytes[index[0]:index[1]]) - +print(json_bytes[index[0] : index[1]]) print(b"00111000".find(b"111", 0, 20)) diff --git a/tests/benchmark/locking.py b/tests/benchmark/locking.py index b8dfd98..325a12d 100644 --- a/tests/benchmark/locking.py +++ b/tests/benchmark/locking.py @@ -1,8 +1,10 @@ -import dictdatabase as DDB -from pyinstrument import profiler +import shutil from pathlib import Path + +from pyinstrument import profiler + +import dictdatabase as DDB from dictdatabase import locking -import shutil DDB.config.storage_directory = "./.benchmark_locking" path = Path(DDB.config.storage_directory) @@ -12,20 +14,20 @@ # 05.11.22: 4520ms # 25.11.22: 4156ms with profiler.Profiler() as p: - for _ in range(25_000): - l = locking.ReadLock("db") - l._lock() - l._unlock() + for _ in range(25_000): + l = locking.ReadLock("db") + 
l._lock() + l._unlock() p.open_in_browser() # 05.11.22: 4884ms # 25.11.22: 4159ms with profiler.Profiler() as p: - for _ in range(25_000): - l = locking.WriteLock("db") - l._lock() - l._unlock() + for _ in range(25_000): + l = locking.WriteLock("db") + l._lock() + l._unlock() p.open_in_browser() diff --git a/tests/benchmark/parallel_appends.py b/tests/benchmark/parallel_appends.py index 7515c18..025ebc8 100644 --- a/tests/benchmark/parallel_appends.py +++ b/tests/benchmark/parallel_appends.py @@ -1,13 +1,14 @@ -from calendar import c import json -import dictdatabase as DDB -from multiprocessing import Pool +import os import shutil import time -import os +from calendar import c +from multiprocessing import Pool + from pyinstrument import Profiler +from utils import db_job, make_table, print_and_assert_results -from utils import print_and_assert_results, db_job, make_table +import dictdatabase as DDB def proc_job(id, n): @@ -17,18 +18,20 @@ def proc_job(id, n): t1 = time.monotonic_ns() with DDB.at("append_here").session() as (session, db): if len(db) == 0: - db += [{ - "counter": 0, - "firstname": "John", - "lastname": "Doe", - "age": 42, - "address": "1234 Main St", - "city": "Anytown", - "state": "CA", - "zip": "12345", - "phone": "123-456-7890", - "interests": ["Python", "Databases", "DDB", "DDB-CLI", "DDB-Web", "Google"], - }] * 50000 + db += [ + { + "counter": 0, + "firstname": "John", + "lastname": "Doe", + "age": 42, + "address": "1234 Main St", + "city": "Anytown", + "state": "CA", + "zip": "12345", + "phone": "123-456-7890", + "interests": ["Python", "Databases", "DDB", "DDB-CLI", "DDB-Web", "Google"], + } + ] * 50000 else: db.append({**db[-1], "counter": db[-1]["counter"] + 1}) session.write() @@ -48,9 +51,6 @@ def proc_read_job(id, n): print(f"{(time.monotonic_ns() - t1) / 1e6:.2f} ms {vis}") - - - if __name__ == "__main__": proc_count = 2 per_proc = 100 @@ -61,8 +61,20 @@ def proc_read_job(id, n): t1 = time.monotonic() pool = Pool(processes=proc_count * 2) for i in range(proc_count): - pool.apply_async(proc_job, args=(i, per_proc,)) - pool.apply_async(proc_read_job, args=(i, per_proc,)) + pool.apply_async( + proc_job, + args=( + i, + per_proc, + ), + ) + pool.apply_async( + proc_read_job, + args=( + i, + per_proc, + ), + ) pool.close() pool.join() print(f"⏱️ {time.monotonic() - t1} seconds") diff --git a/tests/benchmark/run_async.py b/tests/benchmark/run_async.py index 4c26bfa..ee7d0ff 100644 --- a/tests/benchmark/run_async.py +++ b/tests/benchmark/run_async.py @@ -1,11 +1,12 @@ -import dictdatabase as DDB import asyncio +import os import shutil import time -import os from utils import incrementor, print_and_assert_results +import dictdatabase as DDB + async def thread_job(i, n, file_count): DDB.locking.SLEEP_TIMEOUT = 0.001 @@ -20,7 +21,6 @@ async def threaded_stress(file_count=2, thread_count=10, per_thread=500): # Create tasks for concurrent execution tasks = [(incrementor, (i, per_thread, file_count)) for i in range(thread_count)] - # Execute process pool running incrementor as the target task t1 = time.monotonic() await asyncio.gather(*[thread_job(i, per_thread, file_count) for i in range(thread_count)]) diff --git a/tests/benchmark/run_big_file.py b/tests/benchmark/run_big_file.py index 3ae1919..431f900 100644 --- a/tests/benchmark/run_big_file.py +++ b/tests/benchmark/run_big_file.py @@ -1,85 +1,75 @@ -import dictdatabase as DDB import random import time +import dictdatabase as DDB + def make_random_posts(count): - posts = {} - for _ in range(count): - id = 
str(random.randint(0, 999_999_999)) - title_length = random.randint(10, 100) - content_length = random.randint(200, 500) - posts[id] = { - 'id': id, - 'title': "".join(random.choices(" abcdefghijklmnopqrstuvwxyz,.", k=title_length)), - 'content': "".join(random.choices(" abcdefghijklmnopqrstuvwxyz,.", k=content_length)), - } - return posts - + posts = {} + for _ in range(count): + id = str(random.randint(0, 999_999_999)) + title_length = random.randint(10, 100) + content_length = random.randint(200, 500) + posts[id] = { + "id": id, + "title": "".join(random.choices(" abcdefghijklmnopqrstuvwxyz,.", k=title_length)), + "content": "".join(random.choices(" abcdefghijklmnopqrstuvwxyz,.", k=content_length)), + } + return posts def make_users(count): - all_users = {} - for i in range(count): - all_users[str(i)] = { - "id": str(i), - "name": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)), - "surname": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=20)), - "age": random.randint(20, 80), - "posts": make_random_posts(random.randint(200, 300)), - } - return all_users - - + all_users = {} + for i in range(count): + all_users[str(i)] = { + "id": str(i), + "name": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)), + "surname": "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=20)), + "age": random.randint(20, 80), + "posts": make_random_posts(random.randint(200, 300)), + } + return all_users def read_specific_users(): - accessed_users = sorted([str(i * 100) for i in range(100)], key=lambda x: random.random()) - t1 = time.monotonic() - for user_id in accessed_users: - print(f"Accessing user {user_id}") - u = DDB.at("big_users", key=user_id).read() - print(f"User {user_id} has {len(u['posts'])} posts and is {u['age']} years old") - t2 = time.monotonic() - print(f"Time taken: {(t2 - t1) * 1000}ms") - + accessed_users = sorted([str(i * 100) for i in range(100)], key=lambda x: random.random()) + t1 = time.monotonic() + for user_id in accessed_users: + print(f"Accessing user {user_id}") + u = DDB.at("big_users", key=user_id).read() + print(f"User {user_id} has {len(u['posts'])} posts and is {u['age']} years old") + t2 = time.monotonic() + print(f"Time taken: {(t2 - t1) * 1000}ms") def write_specific_users(): - accessed_users = sorted([str(i * 100) for i in range(100)], key=lambda x: random.random()) - t1 = time.monotonic() - for user_id in accessed_users: - print(f"Accessing user {user_id}") - - with DDB.at("big_users", key=user_id).session() as (session, user): - user["surname"] = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=random.randint(3, 50))) - session.write() - t2 = time.monotonic() - print(f"Time taken: {(t2 - t1) * 1000}ms") - + accessed_users = sorted([str(i * 100) for i in range(100)], key=lambda x: random.random()) + t1 = time.monotonic() + for user_id in accessed_users: + print(f"Accessing user {user_id}") + with DDB.at("big_users", key=user_id).session() as (session, user): + user["surname"] = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=random.randint(3, 50))) + session.write() + t2 = time.monotonic() + print(f"Time taken: {(t2 - t1) * 1000}ms") def random_access_users(write_read_ratio=0.1, count=500): - accessed_users = [str(i * 100) for i in [random.randint(0, 99) for _ in range(count)]] - t1 = time.monotonic() - for user_id in accessed_users: - - if random.random() < write_read_ratio: - with DDB.at("big_users", key=user_id).session() as (session, user): - user["surname"] = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", 
k=random.randint(3, 50))) - session.write() - print(f"Accessed user {user_id} for writing") - else: - u = DDB.at("big_users", key=user_id).read() - print(f"User {user_id} has {len(u['posts'])} posts and is {u['age']} years old") - - t2 = time.monotonic() - print(f"Time taken: {t2 - t1}s") - - - - + accessed_users = [str(i * 100) for i in [random.randint(0, 99) for _ in range(count)]] + t1 = time.monotonic() + for user_id in accessed_users: + if random.random() < write_read_ratio: + with DDB.at("big_users", key=user_id).session() as (session, user): + user["surname"] = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=random.randint(3, 50))) + session.write() + print(f"Accessed user {user_id} for writing") + else: + u = DDB.at("big_users", key=user_id).read() + print(f"User {user_id} has {len(u['posts'])} posts and is {u['age']} years old") + + t2 = time.monotonic() + print(f"Time taken: {t2 - t1}s") # DDB.at("big_users").create(make_users(20_000), force_overwrite=True) # 2500MB diff --git a/tests/benchmark/run_parallel.py b/tests/benchmark/run_parallel.py index bb8fec3..26e5909 100644 --- a/tests/benchmark/run_parallel.py +++ b/tests/benchmark/run_parallel.py @@ -1,14 +1,15 @@ -from calendar import c import json -import dictdatabase as DDB -from multiprocessing import Pool +import os +import random import shutil import time -import os +from calendar import c from dataclasses import dataclass -import random +from multiprocessing import Pool + from path_dict import PathDict +import dictdatabase as DDB DDB.config.storage_directory = ".ddb_bench_multi" @@ -25,7 +26,9 @@ def wrapper(*args, **kwargs): function(*args, **kwargs) t2 = time.monotonic() print(f"⏱️ {iterations / (t2 - t1):.1f} op/s for {f_name} ({(t2 - t1):.1f} seconds)") + return wrapper + return decorator @@ -53,7 +56,6 @@ def sequential_partial_write_small_file(name): session.write() - @dataclass class Scenario: files: int = 1 @@ -82,9 +84,6 @@ def print_and_assert_results(scenario: Scenario, t): assert db["counter"]["counter"] == scenario.ops * scenario.writers - - - def process_job(mode, scenario, cfg): DDB.config = cfg DDB.locking.SLEEP_TIMEOUT = 0.001 @@ -128,42 +127,27 @@ def parallel_stressor(scenario: Scenario): scenarios = [ - - Scenario(readers=1, ops=6000), Scenario(readers=2, ops=6000), Scenario(readers=4, ops=6000), Scenario(readers=8, ops=3000), - - Scenario(writers=1, ops=6000), Scenario(writers=2, ops=1000), Scenario(writers=4, ops=800), Scenario(writers=8, ops=200), - - Scenario(readers=20, writers=20, ops=30), - - Scenario(readers=8, ops=1500), Scenario(readers=8, ops=1500, use_compression=True), Scenario(readers=8, ops=1500, big_file=True), - - Scenario(readers=8, writers=1, ops=200), Scenario(readers=8, writers=1, ops=25, big_file=True), - - Scenario(readers=1, writers=8, ops=200), Scenario(readers=1, writers=8, ops=10, big_file=True), - - Scenario(readers=8, writers=8, ops=100), Scenario(readers=8, writers=8, ops=8, big_file=True), ] if __name__ == "__main__": - # print("✨ Simple sequential benchmarks") # sequential_full_read_small_file() # sequential_partial_read_small_file() diff --git a/tests/benchmark/run_parallel_multi.py b/tests/benchmark/run_parallel_multi.py index 0df08dc..35ac88c 100644 --- a/tests/benchmark/run_parallel_multi.py +++ b/tests/benchmark/run_parallel_multi.py @@ -1,14 +1,15 @@ -from calendar import c import json -import dictdatabase as DDB -from multiprocessing import Pool -import shutil -import time import os -from pyinstrument import Profiler +import shutil import threading +import 
time +from calendar import c +from multiprocessing import Pool +from pyinstrument import Profiler from utils import print_and_assert_results + +import dictdatabase as DDB from dictdatabase.configuration import Confuguration @@ -23,8 +24,6 @@ def proc_job(n, cfg): session.write() - - def parallel_stressor(file_count): # Create Tables for t in range(11): @@ -44,12 +43,6 @@ def parallel_stressor(file_count): print(r.get()) - - - - - - if __name__ == "__main__": DDB.config.storage_directory = ".ddb_bench_parallel" try: diff --git a/tests/benchmark/run_threaded.py b/tests/benchmark/run_threaded.py index 12f305a..ec3abec 100644 --- a/tests/benchmark/run_threaded.py +++ b/tests/benchmark/run_threaded.py @@ -1,11 +1,12 @@ -import dictdatabase as DDB -import super_py as sp +import json +import os import shutil import time -import os -import json -from utils import print_and_assert_results, db_job +import super_py as sp +from utils import db_job, print_and_assert_results + +import dictdatabase as DDB def threaded_stressor(file_count, readers, writers, operations_per_thread, big_file, compression): @@ -28,13 +29,14 @@ def threaded_stressor(file_count, readers, writers, operations_per_thread, big_f print_and_assert_results(readers, writers, operations_per_process, file_count, big_file, compression, t1, t2) - if __name__ == "__main__": DDB.config.storage_directory = ".ddb_bench_threaded" operations_per_process = 4 for file_count, readers, writers in [(1, 4, 4), (1, 8, 1), (1, 1, 8), (4, 8, 8)]: print("") - print(f"✨ Scenario: {file_count} files, {readers} readers, {writers} writers, {operations_per_process} operations per process") + print( + f"✨ Scenario: {file_count} files, {readers} readers, {writers} writers, {operations_per_process} operations per process" + ) for big_file, compression in [(False, False), (False, True), (True, False), (True, True)]: try: shutil.rmtree(".ddb_bench_threaded", ignore_errors=True) diff --git a/tests/benchmark/sequential_appends.py b/tests/benchmark/sequential_appends.py index 518f4fb..bc91ea8 100644 --- a/tests/benchmark/sequential_appends.py +++ b/tests/benchmark/sequential_appends.py @@ -1,26 +1,34 @@ -from calendar import c import json -import dictdatabase as DDB -from multiprocessing import Pool +import os import shutil import time -import os +from calendar import c +from multiprocessing import Pool + from pyinstrument import Profiler +import dictdatabase as DDB + def seq_job(n): - DDB.at("db").create([{ - "counter": 0, - "firstname": "John", - "lastname": "Doe", - "age": 42, - "address": "1234 Main St", - "city": "Anytown", - "state": "CA", - "zip": "12345", - "phone": "123-456-7890", - "interests": ["Python", "Databases", "DDB", "DDB-CLI", "DDB-Web", "Google"], - }] * 50000, force_overwrite=True) + DDB.at("db").create( + [ + { + "counter": 0, + "firstname": "John", + "lastname": "Doe", + "age": 42, + "address": "1234 Main St", + "city": "Anytown", + "state": "CA", + "zip": "12345", + "phone": "123-456-7890", + "interests": ["Python", "Databases", "DDB", "DDB-CLI", "DDB-Web", "Google"], + } + ] + * 50000, + force_overwrite=True, + ) for _ in range(n): t1 = time.monotonic_ns() with DDB.at("db").session() as (session, db): diff --git a/tests/benchmark/sqlite/test.py b/tests/benchmark/sqlite/test.py index 24de9ee..17f0f0b 100644 --- a/tests/benchmark/sqlite/test.py +++ b/tests/benchmark/sqlite/test.py @@ -1,16 +1,14 @@ -import time import os import sqlite3 -import os -import super_py as sp +import time +import super_py as sp def teardown(): os.remove("test.db") - 
@sp.test(teardown=teardown) def parallel_stress(tables=4, processes=16, per_process=128): # Create the database with all tables diff --git a/tests/benchmark/sqlite/test_parallel_runner.py b/tests/benchmark/sqlite/test_parallel_runner.py index 65a5aa7..6497252 100644 --- a/tests/benchmark/sqlite/test_parallel_runner.py +++ b/tests/benchmark/sqlite/test_parallel_runner.py @@ -1,7 +1,6 @@ -from multiprocessing import Pool -import sys import sqlite3 - +import sys +from multiprocessing import Pool def incr_db(n, tables): @@ -15,7 +14,6 @@ def incr_db(n, tables): return True - if __name__ == "__main__": tables = int(sys.argv[1]) processes = int(sys.argv[2]) @@ -23,6 +21,12 @@ def incr_db(n, tables): pool = Pool(processes=processes) for _ in range(processes): - pool.apply_async(incr_db, args=(per_process, tables,)) + pool.apply_async( + incr_db, + args=( + per_process, + tables, + ), + ) pool.close() pool.join() diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py index 202742b..29b004b 100644 --- a/tests/benchmark/utils.py +++ b/tests/benchmark/utils.py @@ -1,8 +1,10 @@ -import dictdatabase as DDB -from path_dict import pd import random import time +from path_dict import pd + +import dictdatabase as DDB + def make_table(recursion_depth=3, keys_per_level=50): d = {"key1": "val1", "key2": 2, "key3": [1, "2", [3, 3]]} diff --git a/tests/conftest.py b/tests/conftest.py index ee6eae8..6d1bddb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,34 +1,32 @@ from pathlib import Path -import dictdatabase as DDB + import pytest +import dictdatabase as DDB + @pytest.fixture(autouse=True) def isolate_database_files(tmp_path: Path): DDB.config.storage_directory = str(tmp_path) - @pytest.fixture(scope="function") def name_of_test(request): return request.function.__name__ - @pytest.fixture(params=[True, False]) def use_compression(request): DDB.config.use_compression = request.param return request.param - @pytest.fixture(params=[True, False]) def use_orjson(request): DDB.config.use_orjson = request.param return request.param - @pytest.fixture(params=[None, 0, 2, "\t"]) def indent(request): DDB.config.indent = request.param diff --git a/tests/system_checks/test_clocks.py b/tests/system_checks/test_clocks.py index 85a25d1..5aa3814 100644 --- a/tests/system_checks/test_clocks.py +++ b/tests/system_checks/test_clocks.py @@ -4,38 +4,38 @@ def print_clocks(label: str) -> None: - print(f"--- {label} ---") - print("time_ns() :", time.time_ns()) - print("monotonic_ns() :", time.monotonic_ns()) - print("perf_counter_ns():", time.perf_counter_ns()) - print("\n") + print(f"--- {label} ---") + print("time_ns() :", time.time_ns()) + print("monotonic_ns() :", time.monotonic_ns()) + print("perf_counter_ns():", time.perf_counter_ns()) + print("\n") def thread_function(thread_name: str) -> None: - print_clocks(f"Thread-{thread_name}") + print_clocks(f"Thread-{thread_name}") def process_function(process_name: str) -> None: - print_clocks(f"Process-{process_name}") + print_clocks(f"Process-{process_name}") if __name__ == "__main__": - print_clocks("Main Thread") + print_clocks("Main Thread") - threads = [] - for i in range(3): - thread = threading.Thread(target=thread_function, args=(i,)) - thread.start() - threads.append(thread) + threads = [] + for i in range(3): + thread = threading.Thread(target=thread_function, args=(i,)) + thread.start() + threads.append(thread) - for thread in threads: - thread.join() + for thread in threads: + thread.join() - processes = [] - for i in range(3): - process = 
multiprocessing.Process(target=process_function, args=(i,)) - process.start() - processes.append(process) + processes = [] + for i in range(3): + process = multiprocessing.Process(target=process_function, args=(i,)) + process.start() + processes.append(process) - for process in processes: - process.join() + for process in processes: + process.join() diff --git a/tests/system_checks/test_monotonic_over_threads.py b/tests/system_checks/test_monotonic_over_threads.py index 7bcd5b3..3d6c35f 100644 --- a/tests/system_checks/test_monotonic_over_threads.py +++ b/tests/system_checks/test_monotonic_over_threads.py @@ -7,12 +7,12 @@ # Define the clocks to test clocks = { - "time ": time.time, - "time_ns ": time.time_ns, - "monotonic ": time.monotonic, - "monotonic_ns ": time.monotonic_ns, - "perf_counter ": time.perf_counter, - "perf_counter_ns": time.perf_counter_ns, + "time ": time.time, + "time_ns ": time.time_ns, + "monotonic ": time.monotonic, + "monotonic_ns ": time.monotonic_ns, + "perf_counter ": time.perf_counter, + "perf_counter_ns": time.perf_counter_ns, } # Queue to store timestamps in order @@ -20,44 +20,50 @@ def capture_time(i, clock_func: callable) -> None: - # Capture time using the given clock function and put it in the queue - for _ in range(1000): - # print(f"Thread {i} capturing time") - timestamps.put(clock_func()) + # Capture time using the given clock function and put it in the queue + for _ in range(1000): + # print(f"Thread {i} capturing time") + timestamps.put(clock_func()) def check_monotonicity_for_clock(clock_name: str, clock_func: callable) -> None: - # Clear the queue for the next clock - while not timestamps.empty(): - timestamps.get() + # Clear the queue for the next clock + while not timestamps.empty(): + timestamps.get() - # Create and start threads - threads = [] - for i in range(NUM_THREADS): - thread = threading.Thread(target=capture_time, args=(i, clock_func,)) - thread.start() - threads.append(thread) + # Create and start threads + threads = [] + for i in range(NUM_THREADS): + thread = threading.Thread( + target=capture_time, + args=( + i, + clock_func, + ), + ) + thread.start() + threads.append(thread) - # Wait for all threads to complete - for thread in threads: - thread.join() + # Wait for all threads to complete + for thread in threads: + thread.join() - # Extract timestamps from the queue - captured_times = [] - while not timestamps.empty(): - captured_times.append(timestamps.get()) + # Extract timestamps from the queue + captured_times = [] + while not timestamps.empty(): + captured_times.append(timestamps.get()) - # Check if the clock is monotonic - is_monotonic = all(captured_times[i] <= captured_times[i+1] for i in range(len(captured_times)-1)) + # Check if the clock is monotonic + is_monotonic = all(captured_times[i] <= captured_times[i + 1] for i in range(len(captured_times) - 1)) - if is_monotonic: - print(f"Clock: {clock_name} is monotonic over {NUM_THREADS} threads ✅") - else: - print(f"Clock: {clock_name} is not monotonic over {NUM_THREADS} threads ❌") - print("-" * 40) + if is_monotonic: + print(f"Clock: {clock_name} is monotonic over {NUM_THREADS} threads ✅") + else: + print(f"Clock: {clock_name} is not monotonic over {NUM_THREADS} threads ❌") + print("-" * 40) if __name__ == "__main__": - # Check monotonicity for each clock - for clock_name, clock_func in clocks.items(): - check_monotonicity_for_clock(clock_name, clock_func) + # Check monotonicity for each clock + for clock_name, clock_func in clocks.items(): + 
check_monotonicity_for_clock(clock_name, clock_func) diff --git a/tests/system_checks/test_tick_rate.py b/tests/system_checks/test_tick_rate.py index c85ed94..e2834e2 100644 --- a/tests/system_checks/test_tick_rate.py +++ b/tests/system_checks/test_tick_rate.py @@ -2,31 +2,31 @@ def get_tick_rate(clock_func: callable) -> float: - start_time = time.time() - measurements = [clock_func() for _ in range(2_000_000)] - end_time = time.time() + start_time = time.time() + measurements = [clock_func() for _ in range(2_000_000)] + end_time = time.time() - ticks = 0 - prev_value = measurements[0] - for current_value in measurements[1:]: - if current_value < prev_value: - raise RuntimeError("Clock function is not monotonic") - if current_value != prev_value: - ticks += 1 - prev_value = current_value + ticks = 0 + prev_value = measurements[0] + for current_value in measurements[1:]: + if current_value < prev_value: + raise RuntimeError("Clock function is not monotonic") + if current_value != prev_value: + ticks += 1 + prev_value = current_value - return ticks / (end_time - start_time) # ticks per second + return ticks / (end_time - start_time) # ticks per second if __name__ == "__main__": - clock_funcs = { - "time ": time.time, - "time_ns ": time.time_ns, - "monotonic ": time.monotonic, - "monotonic_ns ": time.monotonic_ns, - "perf_counter ": time.perf_counter, - "perf_counter_ns": time.perf_counter_ns, - } + clock_funcs = { + "time ": time.time, + "time_ns ": time.time_ns, + "monotonic ": time.monotonic, + "monotonic_ns ": time.monotonic_ns, + "perf_counter ": time.perf_counter, + "perf_counter_ns": time.perf_counter_ns, + } - for name, func in clock_funcs.items(): - print(f"Tick rate for {name}: {get_tick_rate(func) / 1_000_000.0:.3f}M ticks/second") + for name, func in clock_funcs.items(): + print(f"Tick rate for {name}: {get_tick_rate(func) / 1_000_000.0:.3f}M ticks/second") diff --git a/tests/test_at.py b/tests/test_at.py index f0b9882..ff7c44d 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -2,15 +2,15 @@ def test_at(): - assert at("x").path == "x" - assert at("x", "y", "z").path == 'x/y/z' - assert at(["x", "y", "z"]).path == 'x/y/z' - assert at("x", ["y", "z"]).path == 'x/y/z' - assert at(["x", "y"], "z").path == 'x/y/z' - assert at(["x"], "y", "z").path == 'x/y/z' - assert at("x", ["y"], "z").path == 'x/y/z' - assert at("x", "y", ["z"]).path == 'x/y/z' - assert at("x", ["y"], ["z"]).path == 'x/y/z' - assert at(["x"], "y", ["z"]).path == 'x/y/z' - assert at(["x"], ["y"], "z").path == 'x/y/z' - assert at(["x"], ["y"], ["z"]).path == 'x/y/z' + assert at("x").path == "x" + assert at("x", "y", "z").path == "x/y/z" + assert at(["x", "y", "z"]).path == "x/y/z" + assert at("x", ["y", "z"]).path == "x/y/z" + assert at(["x", "y"], "z").path == "x/y/z" + assert at(["x"], "y", "z").path == "x/y/z" + assert at("x", ["y"], "z").path == "x/y/z" + assert at("x", "y", ["z"]).path == "x/y/z" + assert at("x", ["y"], ["z"]).path == "x/y/z" + assert at(["x"], "y", ["z"]).path == "x/y/z" + assert at(["x"], ["y"], "z").path == "x/y/z" + assert at(["x"], ["y"], ["z"]).path == "x/y/z" diff --git a/tests/test_create.py b/tests/test_create.py index dc01707..8f84e67 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,8 +1,9 @@ -import dictdatabase as DDB -from path_dict import pd -import pytest import json +import pytest +from path_dict import pd + +import dictdatabase as DDB from tests.utils import make_complex_nested_random_dict diff --git a/tests/test_delete.py b/tests/test_delete.py index 
f85a0ed..9790c0f 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -1,6 +1,7 @@ -import dictdatabase as DDB import pytest +import dictdatabase as DDB + def test_delete(use_compression, use_orjson, indent): DDB.at("test_delete").create({"a": 1}, force_overwrite=True) @@ -15,6 +16,5 @@ def test_delete(use_compression, use_orjson, indent): DDB.at("test_delete", key="any").delete() - def test_delete_nonexistent(use_compression, use_orjson, indent): DDB.at("test_delete_nonexistent").delete() diff --git a/tests/test_excepts.py b/tests/test_excepts.py index 4029461..28cff0b 100644 --- a/tests/test_excepts.py +++ b/tests/test_excepts.py @@ -1,7 +1,8 @@ -import dictdatabase as DDB -from dictdatabase import utils, io_bytes -from path_dict import pd import pytest +from path_dict import pd + +import dictdatabase as DDB +from dictdatabase import io_bytes, utils def test_except_during_open_session(use_compression, use_orjson, indent): @@ -13,7 +14,6 @@ def test_except_during_open_session(use_compression, use_orjson, indent): raise RuntimeError("Any Exception") - def test_except_on_save_unserializable(use_compression, use_orjson, indent): name = "test_except_on_save_unserializable" with pytest.raises(TypeError): @@ -56,7 +56,6 @@ def test_wildcard_and_subkey_except(use_compression, use_orjson, indent): DDB.at("test_wildcard_and_subkey_except/*", key="key").read() - def test_utils_invalid_json_except(): with pytest.raises(TypeError): utils.seek_index_through_value_bytes(b"{This is not { JSON", 0) diff --git a/tests/test_exists.py b/tests/test_exists.py index 07b8aa3..f887520 100644 --- a/tests/test_exists.py +++ b/tests/test_exists.py @@ -1,6 +1,7 @@ -import dictdatabase as DDB import pytest +import dictdatabase as DDB + def test_exists(use_compression, use_orjson, indent): DDB.at("test_exists").create({"a": 1}, force_overwrite=True) diff --git a/tests/test_indentation.py b/tests/test_indentation.py index 6c34ce9..f2fa4f3 100644 --- a/tests/test_indentation.py +++ b/tests/test_indentation.py @@ -1,20 +1,20 @@ - -import dictdatabase as DDB -import orjson import json -from dictdatabase import utils, io_unsafe, config, io_bytes +import orjson import pytest +import dictdatabase as DDB +from dictdatabase import config, io_bytes, io_unsafe, utils + data = { - 'a': 1, - 'b': { - 'c': 2, + "a": 1, + "b": { + "c": 2, "cl": [1, "\\"], - 'd': { - 'e': 3, + "d": { + "e": 3, "el": [1, "\\"], - } + }, }, "l": [1, "\\"], } @@ -27,8 +27,6 @@ def string_dump(db: dict): return orjson.dumps(db, option=option) - - def test_indentation(use_compression, use_orjson, indent): DDB.at("test_indentation").create(data, force_overwrite=True) diff --git a/tests/test_indexer.py b/tests/test_indexer.py index 1be79d8..e4b5121 100644 --- a/tests/test_indexer.py +++ b/tests/test_indexer.py @@ -2,9 +2,7 @@ def test_indexer(use_compression, use_orjson, indent): - DDB.at("test_indexer").create(force_overwrite=True, data={ - "a": {"e": 4}, "b": 2 - }) + DDB.at("test_indexer").create(force_overwrite=True, data={"a": {"e": 4}, "b": 2}) # Trigger create index entry for key "a" assert DDB.at("test_indexer", key="a").read() == {"e": 4} diff --git a/tests/test_io_bytes.py b/tests/test_io_bytes.py index 83545cf..29a0202 100644 --- a/tests/test_io_bytes.py +++ b/tests/test_io_bytes.py @@ -1,48 +1,47 @@ -from dictdatabase import io_bytes import pytest +from dictdatabase import io_bytes def test_write_bytes(name_of_test, use_compression): - # No partial writing to compressed file allowed - if use_compression: - with 
pytest.raises(RuntimeError): - io_bytes.write(name_of_test, b"test", start=5) - return - # Write shorter content at index - io_bytes.write(name_of_test, b"0123456789") - io_bytes.write(name_of_test, b"abc", start=2) - assert io_bytes.read(name_of_test) == b"01abc" - # Overwrite with shorter content - io_bytes.write(name_of_test, b"xy") - assert io_bytes.read(name_of_test) == b"xy" - # Overwrite with longer content - io_bytes.write(name_of_test, b"0123456789") - io_bytes.write(name_of_test, b"abcdef", start=8) - assert io_bytes.read(name_of_test) == b"01234567abcdef" - # Write at index out of range - io_bytes.write(name_of_test, b"01") - io_bytes.write(name_of_test, b"ab", start=4) - assert io_bytes.read(name_of_test) == b'01\x00\x00ab' - + # No partial writing to compressed file allowed + if use_compression: + with pytest.raises(RuntimeError): + io_bytes.write(name_of_test, b"test", start=5) + return + # Write shorter content at index + io_bytes.write(name_of_test, b"0123456789") + io_bytes.write(name_of_test, b"abc", start=2) + assert io_bytes.read(name_of_test) == b"01abc" + # Overwrite with shorter content + io_bytes.write(name_of_test, b"xy") + assert io_bytes.read(name_of_test) == b"xy" + # Overwrite with longer content + io_bytes.write(name_of_test, b"0123456789") + io_bytes.write(name_of_test, b"abcdef", start=8) + assert io_bytes.read(name_of_test) == b"01234567abcdef" + # Write at index out of range + io_bytes.write(name_of_test, b"01") + io_bytes.write(name_of_test, b"ab", start=4) + assert io_bytes.read(name_of_test) == b"01\x00\x00ab" def test_read_bytes(name_of_test, use_compression): - io_bytes.write(name_of_test, b"0123456789") - # In range - assert io_bytes.read(name_of_test, start=2, end=5) == b"234" - # Normal ranges - assert io_bytes.read(name_of_test, start=0, end=10) == b"0123456789" - assert io_bytes.read(name_of_test, start=2) == b"23456789" - assert io_bytes.read(name_of_test, end=2) == b"01" - assert io_bytes.read(name_of_test) == b"0123456789" - # End out of range - assert io_bytes.read(name_of_test, start=9, end=20) == b"9" - # Completely out of range - assert io_bytes.read(name_of_test, start=25, end=30) == b"" - # Start negative - if use_compression: - assert io_bytes.read(name_of_test, start=-5, end=3) == b"" - else: - with pytest.raises(OSError): - io_bytes.read(name_of_test, start=-5, end=3) + io_bytes.write(name_of_test, b"0123456789") + # In range + assert io_bytes.read(name_of_test, start=2, end=5) == b"234" + # Normal ranges + assert io_bytes.read(name_of_test, start=0, end=10) == b"0123456789" + assert io_bytes.read(name_of_test, start=2) == b"23456789" + assert io_bytes.read(name_of_test, end=2) == b"01" + assert io_bytes.read(name_of_test) == b"0123456789" + # End out of range + assert io_bytes.read(name_of_test, start=9, end=20) == b"9" + # Completely out of range + assert io_bytes.read(name_of_test, start=25, end=30) == b"" + # Start negative + if use_compression: + assert io_bytes.read(name_of_test, start=-5, end=3) == b"" + else: + with pytest.raises(OSError): + io_bytes.read(name_of_test, start=-5, end=3) diff --git a/tests/test_io_safe.py b/tests/test_io_safe.py index 7f433d2..0ac3aaf 100644 --- a/tests/test_io_safe.py +++ b/tests/test_io_safe.py @@ -1,28 +1,30 @@ +import json + +import pytest + import dictdatabase as DDB from dictdatabase import io_safe -import pytest -import json def test_read(use_compression, use_orjson, indent): - # Elicit read error - DDB.config.use_orjson = True - with pytest.raises(json.decoder.JSONDecodeError): - with 
open(f"{DDB.config.storage_directory}/corrupted_json.json", "w") as f: - f.write("This is not JSON") - io_safe.read("corrupted_json") + # Elicit read error + DDB.config.use_orjson = True + with pytest.raises(json.decoder.JSONDecodeError): + with open(f"{DDB.config.storage_directory}/corrupted_json.json", "w") as f: + f.write("This is not JSON") + io_safe.read("corrupted_json") def test_partial_read(use_compression, use_orjson, indent): - assert io_safe.partial_read("nonexistent", key="none") is None + assert io_safe.partial_read("nonexistent", key="none") is None def test_write(use_compression, use_orjson, indent): - with pytest.raises(TypeError): - io_safe.write("nonexistent", lambda x: x) + with pytest.raises(TypeError): + io_safe.write("nonexistent", lambda x: x) def test_delete(use_compression, use_orjson, indent): - DDB.at("to_be_deleted").create() - DDB.at("to_be_deleted").delete() - assert DDB.at("to_be_deleted").read() is None + DDB.at("to_be_deleted").create() + DDB.at("to_be_deleted").delete() + assert DDB.at("to_be_deleted").read() is None diff --git a/tests/test_locking.py b/tests/test_locking.py index 58f07ad..dc3e013 100644 --- a/tests/test_locking.py +++ b/tests/test_locking.py @@ -1,8 +1,10 @@ -from dictdatabase import locking -import pytest import threading import time +import pytest + +from dictdatabase import locking + def test_double_lock_exception(use_compression): name = "test_double_lock_exception" @@ -36,9 +38,6 @@ def test_get_lock_names(use_compression): lock._unlock() - - - def test_remove_orphaned_locks(): prev_config = locking.LOCK_TIMEOUT locking.LOCK_TIMEOUT = 0.1 diff --git a/tests/test_parallel_crud.py b/tests/test_parallel_crud.py index a1293be..335b6f7 100644 --- a/tests/test_parallel_crud.py +++ b/tests/test_parallel_crud.py @@ -10,7 +10,7 @@ def do_create(name_of_test: str, return_dict: dict, id_counter: dict, operations key = f"{id_counter['id']}" db[key] = {"counter": 0} id_counter["id"] += 1 - operations['create'] += 1 + operations["create"] += 1 session.write() return_dict["created_ids"] += [key] @@ -20,7 +20,7 @@ def do_update(name_of_test: str, return_dict: dict, operations: dict) -> None: with DDB.at(name_of_test).session() as (session, db): key = random.choice(return_dict["created_ids"]) db[key]["counter"] += 1 - operations['increment'] += 1 + operations["increment"] += 1 session.write() @@ -29,7 +29,7 @@ def do_delete(name_of_test: str, return_dict: dict, operations: dict) -> None: with DDB.at(name_of_test).session() as (session, db): key = random.choice(return_dict["created_ids"]) operations["increment"] -= db[key]["counter"] - operations['delete'] += 1 + operations["delete"] += 1 db.pop(key) return_dict["created_ids"] = [i for i in return_dict["created_ids"] if i != key] session.write() @@ -39,7 +39,7 @@ def do_read(name_of_test: str, return_dict: dict, operations: dict) -> None: # read a counter key = random.choice(return_dict["created_ids"]) DDB.at(name_of_test, key=key).read() - operations['read'] += 1 + operations["read"] += 1 def worker_process(name_of_test: str, i: int, return_dict: dict, id_counter: dict) -> None: @@ -47,9 +47,9 @@ def worker_process(name_of_test: str, i: int, return_dict: dict, id_counter: dic random.seed(i) DDB.config.storage_directory = ".ddb_bench_threaded" operations = { - 'create': 0, - 'increment': 0, - 'read': 0, + "create": 0, + "increment": 0, + "read": 0, "delete": 0, } @@ -68,12 +68,9 @@ def worker_process(name_of_test: str, i: int, return_dict: dict, id_counter: dic return_dict[i] = operations - def 
test_multiprocessing_crud(name_of_test, use_compression, use_orjson): pre_fill_count = 500 - DDB.at(name_of_test).create({ - f"{i}": {"counter": 0} for i in range(pre_fill_count) - }, force_overwrite=True) + DDB.at(name_of_test).create({f"{i}": {"counter": 0} for i in range(pre_fill_count)}, force_overwrite=True) manager = Manager() return_dict = manager.dict() @@ -96,9 +93,9 @@ def test_multiprocessing_crud(name_of_test, use_compression, use_orjson): db_state = DDB.at(name_of_test).read() - logged_increment_ops = sum(x['increment'] for k, x in return_dict.items() if k != "created_ids") - assert logged_increment_ops == sum(x['counter'] for x in db_state.values()) + logged_increment_ops = sum(x["increment"] for k, x in return_dict.items() if k != "created_ids") + assert logged_increment_ops == sum(x["counter"] for x in db_state.values()) - logged_create_ops = sum(x['create'] for k, x in return_dict.items() if k != "created_ids") - logged_delete_ops = sum(x['delete'] for k, x in return_dict.items() if k != "created_ids") + logged_create_ops = sum(x["create"] for k, x in return_dict.items() if k != "created_ids") + logged_delete_ops = sum(x["delete"] for k, x in return_dict.items() if k != "created_ids") assert pre_fill_count + logged_create_ops - logged_delete_ops == len(db_state.keys()) diff --git a/tests/test_parallel_sessions.py b/tests/test_parallel_sessions.py index 044c5fa..9aa4883 100644 --- a/tests/test_parallel_sessions.py +++ b/tests/test_parallel_sessions.py @@ -1,7 +1,9 @@ -import dictdatabase as DDB -from path_dict import pd from multiprocessing.pool import Pool +from path_dict import pd + +import dictdatabase as DDB + def increment_counters(n, tables, cfg): DDB.config.storage_directory = cfg.storage_directory @@ -77,11 +79,6 @@ def test_heavy_multiprocessing(): assert db["counter"] == threads * per_thread - - - - - def read_partial(n, cfg): DDB.locking.SLEEP_TIMEOUT = 0 DDB.config = cfg diff --git a/tests/test_partial.py b/tests/test_partial.py index dff2e24..ef55de6 100644 --- a/tests/test_partial.py +++ b/tests/test_partial.py @@ -1,7 +1,9 @@ -import dictdatabase as DDB -from path_dict import pd import json + import pytest +from path_dict import pd + +import dictdatabase as DDB def test_subread(use_compression, use_orjson, indent): @@ -17,7 +19,6 @@ def test_subread(use_compression, use_orjson, indent): assert DDB.at(name, key="a").read() == "Hello{}" assert DDB.at(name, where=lambda k, v: isinstance(v, list)).read() == {"b": [0, 1]} - assert DDB.at(name, key="f").read() is None assert DDB.at(name, key="b").read() == [0, 1] diff --git a/tests/test_read.py b/tests/test_read.py index 0601143..d941690 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -1,7 +1,8 @@ +import json -import dictdatabase as DDB import pytest -import json + +import dictdatabase as DDB from tests.utils import make_complex_nested_random_dict @@ -24,10 +25,6 @@ def test_invalid_params(use_compression, use_orjson, indent): DDB.at("test_invalid_params", key="any", where=lambda k, v: True).read() - - - - def test_read_integrity(use_compression, use_orjson, indent): cases = [ r'{"a": "\\", "b": 0}', @@ -51,11 +48,6 @@ def test_read_integrity(use_compression, use_orjson, indent): assert key_b == json.loads(case)["b"] - - - - - def test_create_and_read(use_compression, use_orjson, indent): name = "test_create_and_read" d = make_complex_nested_random_dict(12, 6) diff --git a/tests/test_threaded_sessions.py b/tests/test_threaded_sessions.py index 0f0e21a..67e91ca 100644 --- 
a/tests/test_threaded_sessions.py +++ b/tests/test_threaded_sessions.py @@ -1,7 +1,9 @@ -import dictdatabase as DDB -from path_dict import pd from concurrent.futures import ThreadPoolExecutor, wait +from path_dict import pd + +import dictdatabase as DDB + def increment_counters(n, tables): for _ in range(n): diff --git a/tests/test_utils.py b/tests/test_utils.py index af5cd91..9033c96 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,8 @@ import itertools + import orjson -from dictdatabase import utils, byte_codes + +from dictdatabase import byte_codes, utils def test_seek_index_through_value_bytes(): @@ -21,7 +23,7 @@ def load_with_orjson(bytes, key): return orjson.loads(bytes)[key] def load_with_seeker(bytes, key): - key_bytes = f"\"{key}\":".encode() + key_bytes = f'"{key}":'.encode() a_val_start = bytes.find(key_bytes) + len(key_bytes) if bytes[a_val_start] == byte_codes.SPACE: a_val_start += 1 @@ -65,14 +67,14 @@ def load_with_seeker(bytes, key): "a\\b", "\\", "\\\\", - "\\\\\"", - "\\\"\\", - "\"\\\\", - "\"", - "\"\"", - "\"\"\\", - "\"\\\"", - "\\\"\"", + '\\\\"', + '\\"\\', + '"\\\\', + '"', + '""', + '""\\', + '"\\"', + '\\""', # Booleans True, None, diff --git a/tests/test_where.py b/tests/test_where.py index 21d89a8..4238db6 100644 --- a/tests/test_where.py +++ b/tests/test_where.py @@ -1,22 +1,23 @@ -import dictdatabase as DDB -from path_dict import PathDict import pytest +from path_dict import PathDict + +import dictdatabase as DDB def test_where(use_compression, use_orjson, indent): - for i in range(10): - DDB.at("test_select", i).create({"a": i}, force_overwrite=True) + for i in range(10): + DDB.at("test_select", i).create({"a": i}, force_overwrite=True) - s = DDB.at("test_select/*", where=lambda k, v: v["a"] > 7).read() + s = DDB.at("test_select/*", where=lambda k, v: v["a"] > 7).read() - assert s == {"8": {"a": 8}, "9": {"a": 9}} + assert s == {"8": {"a": 8}, "9": {"a": 9}} - with pytest.raises(KeyError): - DDB.at("test_select/*", where=lambda k, v: v["b"] > 5).read() + with pytest.raises(KeyError): + DDB.at("test_select/*", where=lambda k, v: v["b"] > 5).read() - assert DDB.at("nonexistent/*", where=lambda k, v: v["a"] > 5).read() == {} + assert DDB.at("nonexistent/*", where=lambda k, v: v["a"] > 5).read() == {} - assert DDB.at("nonexistent", where=lambda k, v: v["a"] > 5).read() is None + assert DDB.at("nonexistent", where=lambda k, v: v["a"] > 5).read() is None - s = DDB.at("test_select/*", where=lambda k, v: v.at("a").get() > 7).read(as_type=PathDict) - assert s.get() == {"8": {"a": 8}, "9": {"a": 9}} + s = DDB.at("test_select/*", where=lambda k, v: v.at("a").get() > 7).read(as_type=PathDict) + assert s.get() == {"8": {"a": 8}, "9": {"a": 9}} diff --git a/tests/test_write.py b/tests/test_write.py index 13f34d6..67df936 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,6 +1,7 @@ -import dictdatabase as DDB -from path_dict import pd import pytest +from path_dict import pd + +import dictdatabase as DDB from tests.utils import make_complex_nested_random_dict diff --git a/tests/utils.py b/tests/utils.py index 3606207..53062aa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ -import random -import string import json import os +import random +import string def get_tasks_json(): @@ -11,7 +11,6 @@ def get_tasks_json(): def make_complex_nested_random_dict(max_width, max_depth): - def random_string(choices, md): length = random.randint(0, max_width) letters = string.ascii_letters + "".join(["\\", " ", "🚀", '"']) @@ -48,12 +47,6 
@@ def random_dict(choices, md): res[k] = v return res - return random_dict([ - random_string, - random_int, - random_float, - random_bool, - random_none, - random_list, - random_dict - ], max_depth) + return random_dict( + [random_string, random_int, random_float, random_bool, random_none, random_list, random_dict], max_depth + )
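
The reformatted tests above exercise the public DictDatabase surface: create(), full and partial read() via key=, where= filters, and write sessions. As orientation, here is a minimal usage sketch assembled only from calls that appear verbatim in this diff; the database name "example_db", the stored values, and the storage directory are illustrative, and the comment about directory creation is an assumption rather than documented behaviour.

import dictdatabase as DDB

# The tests configure storage via DDB.config.storage_directory; whether DDB
# creates a missing directory itself is assumed here - adjust as needed.
DDB.config.storage_directory = ".ddb_example"

# Create a file-backed dict; force_overwrite replaces an existing file.
DDB.at("example_db").create({"a": {"counter": 0}, "b": [0, 1]}, force_overwrite=True)

# Full and partial reads; reading a missing database returns None.
assert DDB.at("example_db").read() == {"a": {"counter": 0}, "b": [0, 1]}
assert DDB.at("example_db", key="b").read() == [0, 1]
assert DDB.at("missing_db").read() is None

# A session yields the parsed dict and persists changes on session.write().
with DDB.at("example_db").session() as (session, db):
    db["a"]["counter"] += 1
    session.write()

# A where= filter keeps only the top-level entries matching the predicate.
assert DDB.at("example_db", where=lambda k, v: isinstance(v, list)).read() == {"b": [0, 1]}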
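
The io_bytes tests likewise pin down the low-level byte I/O contract the higher-level API relies on: offset writes on uncompressed files pad past-the-end gaps with NUL bytes, reads clamp an out-of-range end index, and compressed files reject partial writes. Below is a small sketch of that contract, assuming the default uncompressed configuration and an existing storage directory; the file name "bytes_demo" is illustrative.

from dictdatabase import io_bytes

# Writing at an offset past the current end of file pads the gap with NUL bytes.
io_bytes.write("bytes_demo", b"01")
io_bytes.write("bytes_demo", b"ab", start=4)
assert io_bytes.read("bytes_demo") == b"01\x00\x00ab"

# Reads take an optional [start, end) byte range: an end past the file length
# is clamped, and a window entirely out of range yields b"".
io_bytes.write("bytes_demo", b"0123456789")
assert io_bytes.read("bytes_demo", start=2, end=5) == b"234"
assert io_bytes.read("bytes_demo", start=9, end=20) == b"9"
assert io_bytes.read("bytes_demo", start=25, end=30) == b""

# With use_compression enabled, offset writes raise RuntimeError instead
# (see test_write_bytes above).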