diff --git a/dictdatabase/io_bytes.py b/dictdatabase/io_bytes.py index d5a87a9..5a7105a 100644 --- a/dictdatabase/io_bytes.py +++ b/dictdatabase/io_bytes.py @@ -4,43 +4,61 @@ -def read(db_name: str, start: int = None, end: int = None) -> bytes: +def read(db_name: str, *, start: int = None, end: int = None) -> bytes: """ Read the content of a file as bytes. Reading works even when the config changes, so a compressed ddb file can also be read if compression is disabled, and vice versa. - Note: Only specify either both start and end, or none of them. + If no compression is used, efficient reading can be done by specifying a start + and end byte index, such that only the bytes in that range are read from the + file. + + If compression is used, specifying a start and end byte index is still possible, + but the entire file has to be read and decompressed first, and then the bytes + in the range are returned. This is because the compressed file is not seekable. Args: - `db_name`: The name of the database file to read from. - `start`: The start byte index to read from. - `end`: The end byte index to read up to (not included). + + Raises: + - `FileNotFoundError`: If the file does not exist as .json nor .ddb. + - `OSError`: If no compression is used and `start` is negative. + - `FileExistsError`: If the file exists as .json and .ddb. """ json_path, json_exists, ddb_path, ddb_exists = utils.file_info(db_name) if json_exists: if ddb_exists: - raise FileExistsError(f"Inconsistent: \"{db_name}\" exists as .json and .ddb") + raise FileExistsError( + f"Inconsistent: \"{db_name}\" exists as .json and .ddb." + "Please remove one of them." + ) with open(json_path, "rb") as f: - if start is None: + if start is None and end is None: return f.read() + start = start or 0 f.seek(start) if end is None: return f.read() return f.read(end - start) if not ddb_exists: - raise FileNotFoundError(f"DB does not exist: \"{db_name}\"") + raise FileNotFoundError(f"No database file exists for \"{db_name}\"") with open(ddb_path, "rb") as f: json_bytes = zlib.decompress(f.read()) - if start is None: + if start is None and end is None: return json_bytes + start = start or 0 + end = end or len(json_bytes) return json_bytes[start:end] -def write(db_name: str, dump: bytes, start: int = None): + +def write(db_name: str, dump: bytes, *, start: int = None): """ Write the bytes to the file of the db_path. If the db was compressed but no compression is enabled, remove the compressed file, and vice versa. diff --git a/dictdatabase/io_unsafe.py b/dictdatabase/io_unsafe.py index d2ea320..7204a30 100644 --- a/dictdatabase/io_unsafe.py +++ b/dictdatabase/io_unsafe.py @@ -56,7 +56,7 @@ def try_read_bytes_using_indexer(indexer: indexing.Indexer, db_name: str, key: s if (index := indexer.get(key)) is None: return None start, end, _, _, value_hash = index - partial_bytes = io_bytes.read(db_name, start, end) + partial_bytes = io_bytes.read(db_name, start=start, end=end) if value_hash != hashlib.sha256(partial_bytes).hexdigest(): return None return partial_bytes @@ -155,12 +155,12 @@ def try_get_parial_file_handle_by_index(indexer: indexing.Indexer, db_name, key) # If compression is disabled, only the value and suffix have to be read else: - value_and_suffix_bytes = io_bytes.read(db_name, start) + value_and_suffix_bytes = io_bytes.read(db_name, start=start) value_length = end - start value_bytes = value_and_suffix_bytes[:value_length] if value_hash != hashlib.sha256(value_bytes).hexdigest(): # If the hashes don't match, read the prefix to concat the full file bytes - prefix_bytes = io_bytes.read(db_name, 0, start) + prefix_bytes = io_bytes.read(db_name, end=start) return None, prefix_bytes + value_and_suffix_bytes value_data = orjson.loads(value_bytes) partial_dict = PartialDict(None, key, value_data, start, end, value_and_suffix_bytes[value_length:]) diff --git a/tests/test_io_bytes.py b/tests/test_io_bytes.py index 1e43841..a9884f6 100644 --- a/tests/test_io_bytes.py +++ b/tests/test_io_bytes.py @@ -1,23 +1,28 @@ -import pytest from dictdatabase import io_bytes +import pytest -def test_write_bytes(use_test_dir, name_of_test): +def test_write_bytes(use_test_dir, name_of_test, use_compression): + # No partial writing to compressed file allowed + if use_compression: + with pytest.raises(RuntimeError): + io_bytes.write(name_of_test, b"test", start=5) + return # Write shorter content at index io_bytes.write(name_of_test, b"0123456789") - io_bytes.write(name_of_test, b"abc", 2) + io_bytes.write(name_of_test, b"abc", start=2) assert io_bytes.read(name_of_test) == b"01abc" # Overwrite with shorter content io_bytes.write(name_of_test, b"xy") assert io_bytes.read(name_of_test) == b"xy" # Overwrite with longer content io_bytes.write(name_of_test, b"0123456789") - io_bytes.write(name_of_test, b"abcdef", 8) + io_bytes.write(name_of_test, b"abcdef", start=8) assert io_bytes.read(name_of_test) == b"01234567abcdef" # Write at index out of range io_bytes.write(name_of_test, b"01") - io_bytes.write(name_of_test, b"ab", 4) + io_bytes.write(name_of_test, b"ab", start=4) assert io_bytes.read(name_of_test) == b'01\x00\x00ab' @@ -25,18 +30,19 @@ def test_write_bytes(use_test_dir, name_of_test): def test_read_bytes(use_test_dir, name_of_test, use_compression): io_bytes.write(name_of_test, b"0123456789") # In range - assert io_bytes.read(name_of_test, 2, 5) == b"234" - # Complete range - assert io_bytes.read(name_of_test, 0, 10) == b"0123456789" - assert io_bytes.read(name_of_test, 0, None) == b"0123456789" + assert io_bytes.read(name_of_test, start=2, end=5) == b"234" + # Normal ranges + assert io_bytes.read(name_of_test, start=0, end=10) == b"0123456789" + assert io_bytes.read(name_of_test, start=2) == b"23456789" + assert io_bytes.read(name_of_test, end=2) == b"01" assert io_bytes.read(name_of_test) == b"0123456789" # End out of range - assert io_bytes.read(name_of_test, 9, 20) == b"9" + assert io_bytes.read(name_of_test, start=9, end=20) == b"9" # Completely out of range - assert io_bytes.read(name_of_test, 25, 30) == b"" + assert io_bytes.read(name_of_test, start=25, end=30) == b"" # Start negative if use_compression: - assert io_bytes.read(name_of_test, -5, 3) == b"" + assert io_bytes.read(name_of_test, start=-5, end=3) == b"" else: with pytest.raises(OSError): - io_bytes.read(name_of_test, -5, 3) + io_bytes.read(name_of_test, start=-5, end=3)