diff --git a/dictdatabase/io_bytes.py b/dictdatabase/io_bytes.py index 7cc0edb..d5a87a9 100644 --- a/dictdatabase/io_bytes.py +++ b/dictdatabase/io_bytes.py @@ -4,7 +4,7 @@ -def read(db_name: str, start=None, end=None) -> bytes: +def read(db_name: str, start: int = None, end: int = None) -> bytes: """ Read the content of a file as bytes. Reading works even when the config changes, so a compressed ddb file can also be read if compression is @@ -13,10 +13,11 @@ def read(db_name: str, start=None, end=None) -> bytes: Note: Only specify either both start and end, or none of them. Args: - - `db_name`: The name of the database to read from. - - `start`: The start index to read from. - - `end`: The end index to read up to (not included). + - `db_name`: The name of the database file to read from. + - `start`: The start byte index to read from. + - `end`: The end byte index to read up to (not included). """ + json_path, json_exists, ddb_path, ddb_exists = utils.file_info(db_name) if json_exists: @@ -39,7 +40,7 @@ def read(db_name: str, start=None, end=None) -> bytes: -def write(db_name: str, dump: bytes, start=None): +def write(db_name: str, dump: bytes, start: int = None): """ Write the bytes to the file of the db_path. If the db was compressed but no compression is enabled, remove the compressed file, and vice versa. @@ -48,6 +49,8 @@ def write(db_name: str, dump: bytes, start=None): - `db_name`: The name of the database to write to. - `dump`: The bytes to write to the file, representing correct JSON when decoded. + - `start`: The start byte index to write to. If None, the whole file is overwritten. + If the original content was longer, the rest truncated. """ json_path, json_exists, ddb_path, ddb_exists = utils.file_info(db_name) diff --git a/tests/test_io_bytes.py b/tests/test_io_bytes.py new file mode 100644 index 0000000..1e43841 --- /dev/null +++ b/tests/test_io_bytes.py @@ -0,0 +1,42 @@ +import pytest +from dictdatabase import io_bytes + + + +def test_write_bytes(use_test_dir, name_of_test): + # Write shorter content at index + io_bytes.write(name_of_test, b"0123456789") + io_bytes.write(name_of_test, b"abc", 2) + assert io_bytes.read(name_of_test) == b"01abc" + # Overwrite with shorter content + io_bytes.write(name_of_test, b"xy") + assert io_bytes.read(name_of_test) == b"xy" + # Overwrite with longer content + io_bytes.write(name_of_test, b"0123456789") + io_bytes.write(name_of_test, b"abcdef", 8) + assert io_bytes.read(name_of_test) == b"01234567abcdef" + # Write at index out of range + io_bytes.write(name_of_test, b"01") + io_bytes.write(name_of_test, b"ab", 4) + assert io_bytes.read(name_of_test) == b'01\x00\x00ab' + + + +def test_read_bytes(use_test_dir, name_of_test, use_compression): + io_bytes.write(name_of_test, b"0123456789") + # In range + assert io_bytes.read(name_of_test, 2, 5) == b"234" + # Complete range + assert io_bytes.read(name_of_test, 0, 10) == b"0123456789" + assert io_bytes.read(name_of_test, 0, None) == b"0123456789" + assert io_bytes.read(name_of_test) == b"0123456789" + # End out of range + assert io_bytes.read(name_of_test, 9, 20) == b"9" + # Completely out of range + assert io_bytes.read(name_of_test, 25, 30) == b"" + # Start negative + if use_compression: + assert io_bytes.read(name_of_test, -5, 3) == b"" + else: + with pytest.raises(OSError): + io_bytes.read(name_of_test, -5, 3)