diff --git a/superscore/backends/filestore.py b/superscore/backends/filestore.py index 07e4348..7d317cd 100644 --- a/superscore/backends/filestore.py +++ b/superscore/backends/filestore.py @@ -8,7 +8,8 @@ import os import shutil from dataclasses import fields, replace -from typing import Any, Dict, Generator, Optional, Union +from functools import cache +from typing import Any, Container, Dict, Generator, Optional, Union from uuid import UUID, uuid4 from apischema import deserialize, serialize @@ -293,7 +294,7 @@ def search(self, *search_terms: SearchTermType) -> Generator[Entry, None, None]: if attr == "entry_type": conditions.append(isinstance(entry, target)) elif attr == "ancestor": - conditions.append(self._is_ancestor(target, entry.uuid)) + conditions.append(entry.uuid in self._gather_progeny(target)) else: try: # check entry attribute by name @@ -304,17 +305,25 @@ def search(self, *search_terms: SearchTermType) -> Generator[Entry, None, None]: if all(conditions): yield entry - def _is_ancestor(self, ancestor: UUID, entry: UUID): + @cache + def _gather_progeny(self, ancestor: UUID) -> Container[UUID]: + """ + Finds all entries accessible from ancestor, and returns their UUIDs. This + makes it easy to check if one entry is hierarchically under another. + + This method is cached to keep runtimes low when checking multiple entries + against the same ancestor. To keep data valid, clear this method's cache + """ + progeny = set() q = [ancestor] while len(q) > 0: - search_entry = q.pop() - if not isinstance(search_entry, Entry): - search_entry = self._entry_cache[search_entry] - if search_entry.uuid == entry: - return True - elif isinstance(search_entry, Nestable): - q.extend(search_entry.children) - return False + cur = q.pop() + if not isinstance(cur, Entry): + cur = self._entry_cache[cur] + progeny.add(cur.uuid) + if isinstance(cur, Nestable): + q.extend(cur.children) + return progeny @contextlib.contextmanager def _load_and_store_context(self) -> Generator[Dict[UUID, Any], None, None]: diff --git a/superscore/tests/test_backend.py b/superscore/tests/test_backend.py index 288c46f..066f70e 100644 --- a/superscore/tests/test_backend.py +++ b/superscore/tests/test_backend.py @@ -205,12 +205,11 @@ def test_update_entry(backends: _Backend): @pytest.mark.parametrize("filestore_backend", [("linac_data",)], indirect=True) -def test_is_ancestor(filestore_backend: _Backend): - assert filestore_backend._is_ancestor( - UUID("06282731-33ea-4270-ba14-098872e627dc"), - UUID("927ef6cb-e45f-4175-aa5f-6c6eec1f3ae4") - ) # top-level snapshot - assert filestore_backend._is_ancestor( - UUID("2f709b4b-79da-4a8b-8693-eed2c389cb3a"), - UUID("927ef6cb-e45f-4175-aa5f-6c6eec1f3ae4") - ) # direct parent snapshot +def test_gather_progeny(filestore_backend: _Backend): + # top-level snapshot + progeny = filestore_backend._gather_progeny(UUID("06282731-33ea-4270-ba14-098872e627dc")) + assert UUID("927ef6cb-e45f-4175-aa5f-6c6eec1f3ae4") in progeny + + # direct parent snapshot + progeny = filestore_backend._gather_progeny(UUID("2f709b4b-79da-4a8b-8693-eed2c389cb3a")) + assert UUID("927ef6cb-e45f-4175-aa5f-6c6eec1f3ae4") in progeny