diff --git a/dictdatabase/dataclasses.py b/dictdatabase/dataclasses.py new file mode 100644 index 0000000..2c54e19 --- /dev/null +++ b/dictdatabase/dataclasses.py @@ -0,0 +1,8 @@ +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class SearchResult: + start_byte: int + end_byte: int + found: bool diff --git a/dictdatabase/io_unsafe.py b/dictdatabase/io_unsafe.py index fbc405b..7c8cba2 100644 --- a/dictdatabase/io_unsafe.py +++ b/dictdatabase/io_unsafe.py @@ -185,16 +185,16 @@ def get_partial_file_handle(db_name: str, key: str) -> PartialFileHandle: return partial_handle # Not found in index file, search for key in the entire file - key_start, key_end, found = searching.search_key_position_in_db(all_file_bytes, key) + position = searching.search_key_position_in_db(all_file_bytes, key) - if not found: + if not position.found: raise KeyError(f"Key \"{key}\" not found in db \"{db_name}\"") # Key found, now determine the bounding byte indices of the value - start = key_end + (1 if all_file_bytes[key_end] == byte_codes.SPACE else 0) + start = position.end_byte + (1 if all_file_bytes[position.end_byte] == byte_codes.SPACE else 0) end = utils.seek_index_through_value_bytes(all_file_bytes, start) - indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, key_start) + indent_level, indent_with = utils.detect_indentation_in_json_bytes(all_file_bytes, position.start_byte) partial_value = orjson.loads(all_file_bytes[start:end]) prefix_bytes = all_file_bytes[:start] if config.use_compression else None diff --git a/dictdatabase/searching.py b/dictdatabase/searching.py index 62ed857..697819b 100644 --- a/dictdatabase/searching.py +++ b/dictdatabase/searching.py @@ -4,9 +4,10 @@ from dictdatabase import byte_codes from dictdatabase import utils +from dictdatabase.dataclasses import SearchResult -def find_key_position_in_bytes(file: bytes, key: str) -> Tuple[int, int, bool]: +def find_key_position_in_bytes(file: bytes, key: str) -> SearchResult: """ It finds the start and end indices of the value of a key in a JSON file @@ -19,13 +20,15 @@ def find_key_position_in_bytes(file: bytes, key: str) -> Tuple[int, int, bool]: """ key_start, key_end = utils.find_outermost_key_in_json_bytes(file, key) if key_end == -1: - return -1, -1, False + return SearchResult(start_byte=-1, end_byte=-1, found=False) start = key_end + (1 if file[key_end] == byte_codes.SPACE else 0) end = utils.seek_index_through_value_bytes(file, start) - return start, end, True + return SearchResult(start_byte=start, end_byte=end, found=True) -def search_key_position_in_db(file: bytes, key: str, glom_searching=True) -> Tuple[int, int, bool]: +def search_key_position_in_db( + file: bytes, key: str, glom_searching=True +) -> SearchResult: original_value_start = 0 original_value_end = len(file) original_key_start = 0 @@ -33,14 +36,14 @@ def search_key_position_in_db(file: bytes, key: str, glom_searching=True) -> Tup for k in key.split(".") if glom_searching else [key]: key_start, key_end = utils.find_outermost_key_in_json_bytes(file, k) if key_end == -1: - return -1, -1, False + return SearchResult(start_byte=-1, end_byte=-1, found=False) original_key_end = original_value_start + key_end original_key_start = original_value_start + key_start - value_start, value_end, found = find_key_position_in_bytes(file, k) + position = find_key_position_in_bytes(file, k) original_value_end = original_value_start + original_value_end - original_value_start += value_start + original_value_start += position.start_byte file = file[original_value_start:original_value_end] - return original_key_start, original_key_end, True + return SearchResult(start_byte=original_key_start, end_byte=original_key_end, found=True) def search_value_position_in_db( @@ -61,11 +64,11 @@ def search_value_position_in_db( original_start = 0 original_end = len(all_file_bytes) for k in key.split(".") if glom_searching else [key]: - start, end, found = find_key_position_in_bytes( + position = find_key_position_in_bytes( all_file_bytes[original_start:original_end], k ) - if not found: + if not position.found: return -1, -1, False - original_end = original_start + end - original_start += start + original_end = original_start + position.end_byte + original_start += position.start_byte return original_start, original_end, True diff --git a/tests/test_glom_writing.py b/tests/test_glom_writing.py index 2702884..4255cd0 100644 --- a/tests/test_glom_writing.py +++ b/tests/test_glom_writing.py @@ -1,3 +1,5 @@ +import pytest + import dictdatabase as DDB data = { @@ -15,3 +17,11 @@ def test_glom_writing(): purchase["status"] = "cancelled" session.write() assert DDB.at("users", key="users.Ben.status").read() == "cancelled" + + +def test_glom_writing_sub_key_not_exists(): + DDB.at("users").create(data, force_overwrite=True) + with pytest.raises(KeyError): + with DDB.at("users", key="users.SUBKEY").session() as (session, purchase): + purchase["status"] = "cancelled" + session.write()