Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prefetch_relationships to work with hierarchical nodes #4672

Draft
wants to merge 1 commit into
base: stable
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship
from infrahub.core.relationship import Relationship, RelationshipManager
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
Expand Down Expand Up @@ -1138,8 +1138,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements

# if prefetch_relationships is enabled
# Query all the peers associated with all nodes at once.
peers_per_node = None
peers = None
peers_per_node: dict[str, dict[str, list[str]]] = {}
peers: dict[str, Node] = {}
if prefetch_relationships:
query = await NodeListGetRelationshipsQuery.init(
db=db, ids=ids, branch=branch, at=at, branch_agnostic=branch_agnostic
Expand All @@ -1152,7 +1152,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
for node_peers in node_data.values():
peer_ids.extend(node_peers)

peer_ids = list(set(peer_ids))
# query the peers that are not already part of the main list
peer_ids = list(set(peer_ids) - set(ids))
peers = await cls.get_many(
ids=peer_ids,
branch=branch,
Expand All @@ -1162,7 +1163,7 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
include_source=include_source,
)

nodes = {}
nodes: dict[str, Node] = {}

for node_id in ids: # pylint: disable=too-many-nested-blocks
if node_id not in nodes_info_by_id:
Expand All @@ -1189,19 +1190,6 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
for attr_name, attr in node_attributes[node_id].attrs.items():
new_node_data[attr_name] = attr

# --------------------------------------------------------
# Relationships
# --------------------------------------------------------
if prefetch_relationships and peers:
for rel_schema in node.schema.relationships:
if node_id in peers_per_node and rel_schema.identifier in peers_per_node[node_id]:
rel_peers = [peers.get(id) for id in peers_per_node[node_id][rel_schema.identifier]]
if rel_schema.cardinality == "one":
if len(rel_peers) == 1:
new_node_data[rel_schema.name] = rel_peers[0]
elif rel_schema.cardinality == "many":
new_node_data[rel_schema.name] = rel_peers

new_node_data_with_profile_overrides = profile_index.apply_profiles(new_node_data)
node_class = identify_node_class(node=node)
node_branch = await registry.get_branch(db=db, branch=node.branch)
Expand All @@ -1210,6 +1198,27 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements

nodes[node_id] = item

# --------------------------------------------------------
# Relationships
# --------------------------------------------------------
if prefetch_relationships:
for node_id, node in nodes.items():
if node_id not in peers_per_node.keys():
continue

for rel_schema in node._schema.relationships:
direction_identifier = f"{rel_schema.direction.value}::{rel_schema.identifier}"
if direction_identifier in peers_per_node[node_id]:
rel_peers = [
peers.get(id, None) or nodes.get(id) for id in peers_per_node[node_id][direction_identifier]
]
rel_manager: RelationshipManager = getattr(node, rel_schema.name)
if rel_schema.cardinality == "one" and not len(rel_peers) == 1:
raise ValueError("Only one relationship expected")

rel_manager.has_fetched_relationships = True
await rel_manager.update(db=db, data=rel_peers)

return nodes

@classmethod
Expand Down
44 changes: 30 additions & 14 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def _extract_attribute_data(self, result: QueryResult) -> AttributeFromDB:

class NodeListGetRelationshipsQuery(Query):
name: str = "node_list_get_relationship"
insert_return: bool = False

def __init__(self, ids: list[str], **kwargs):
self.ids = ids
Expand All @@ -569,28 +570,43 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None:
rels_filter, rels_params = self.branch.get_query_filter_path(at=self.at, branch_agnostic=self.branch_agnostic)
self.params.update(rels_params)

query = (
"""
MATCH (n) WHERE n.uuid IN $ids
MATCH p = ((n)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer))
WHERE all(r IN relationships(p) WHERE (%s))
"""
% rels_filter
)
query = """
MATCH (n1:Node)
WHERE n1.uuid IN $ids
MATCH paths_in = ((n1)<-[r1:IS_RELATED]-(rel1:Relationship)<-[r2:IS_RELATED]-(peer1))
WHERE all(r IN relationships(paths_in) WHERE (%(filters)s))
RETURN n1 as res4_node, rel1 as res3_rel, peer1 as res2_peer, "inbound" as res1_direction
UNION ALL
MATCH (n2:Node)
WHERE n2.uuid IN $ids
MATCH paths_out = ((n2)-[r1:IS_RELATED]->(rel2:Relationship)-[r2:IS_RELATED]->(peer2))
WHERE all(r IN relationships(paths_out) WHERE (%(filters)s))
RETURN n2 as res4_node, rel2 as res3_rel, peer2 as res2_peer, "outbound" as res1_direction
UNION ALL
MATCH (n3:Node)
WHERE n3.uuid IN $ids
MATCH paths_bidir = ((n3)-[r1:IS_RELATED]->(rel3:Relationship)<-[r2:IS_RELATED]-(peer3))
WHERE all(r IN relationships(paths_bidir) WHERE (%(filters)s))
RETURN n3 as res4_node, rel3 as res3_rel, peer3 as res2_peer, "bidirectional" as res1_direction
""" % {"filters": rels_filter}

self.add_to_query(query)

self.return_labels = ["n", "rel", "peer", "r1", "r2"]
# NOTE Not sure why but when using UNION memgraph 2.19 is returning the result in alphabetically reverse order
# instead of respecting the order defined in the query
# In order to have a consistent ordering, all the results have been prepended with res<id>
self.return_labels = ["res4_node", "res3_rel", "res2_peer", "res1_direction"]

def get_peers_group_by_node(self) -> dict[str, dict[str, list[str]]]:
peers_by_node = defaultdict(lambda: defaultdict(list))

for result in self.get_results_group_by(("n", "uuid"), ("rel", "name"), ("peer", "uuid")):
node_id = result.get("n").get("uuid")
rel_name = result.get("rel").get("name")
peer_id = result.get("peer").get("uuid")
for result in self.get_results_group_by(("res4_node", "uuid"), ("res3_rel", "name"), ("res2_peer", "uuid")):
node_id = result.get_node("res4_node").get("uuid")
rel_name = result.get_node("res3_rel").get("name")
peer_id = result.get_node("res2_peer").get("uuid")
direction = result.get_as_str("res1_direction")

peers_by_node[node_id][rel_name].append(peer_id)
peers_by_node[node_id][f"{direction}::{rel_name}"].append(peer_id)

return peers_by_node

Expand Down
6 changes: 4 additions & 2 deletions backend/infrahub/core/relationship/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,13 @@ async def _fetch_relationships(
for peer_id in details.peer_ids_present_local_only:
await self.remove(peer_id=peer_id, db=db)

async def get(self, db: InfrahubDatabase) -> Union[Relationship, list[Relationship]]:
async def get(self, db: InfrahubDatabase) -> Relationship | list[Relationship] | None:
rels = await self.get_relationships(db=db)

if self.schema.cardinality == "one":
if self.schema.cardinality == "one" and rels:
return rels[0]
if self.schema.cardinality == "one" and not rels:
return None

return rels

Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,14 +2118,14 @@ async def hierarchical_location_schema(
@pytest.fixture
async def hierarchical_location_data_simple(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema_simple
) -> Dict[str, Node]:
) -> dict[str, Node]:
return await _build_hierarchical_location_data(db=db)


@pytest.fixture
async def hierarchical_location_data(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema
) -> Dict[str, Node]:
) -> dict[str, Node]:
return await _build_hierarchical_location_data(db=db)


Expand Down
27 changes: 27 additions & 0 deletions backend/tests/unit/core/test_manager_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from infrahub.core.node import Node
from infrahub.core.query.node import NodeToProcess
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship
from infrahub.core.schema import NodeSchema
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.core.timestamp import Timestamp
Expand Down Expand Up @@ -263,6 +264,32 @@ async def test_get_many_prefetch(db: InfrahubDatabase, default_branch: Branch, p
assert tags[1]._peer


async def test_get_many_prefetch_hierarchical(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node]
):
nodes_to_query = ["europe", "asia", "paris", "chicago", "london-r1"]
node_ids = [hierarchical_location_data[value].id for value in nodes_to_query]
nodes = await NodeManager.get_many(db=db, ids=node_ids, prefetch_relationships=True)
assert len(nodes) == 5

paris_id = hierarchical_location_data["paris"].id
europe_id = hierarchical_location_data["europe"].id

assert nodes[paris_id]
children_paris = await nodes[paris_id].children.get(db=db)
assert len(children_paris) == 2
parent_paris = await nodes[paris_id].parent.get(db=db)
assert isinstance(parent_paris, Relationship)
assert parent_paris.peer_id == europe_id

europe_id = hierarchical_location_data["europe"].id
assert nodes[europe_id]
children_europe = await nodes[europe_id].children.get(db=db)
assert len(children_europe) == 2
parent_europe = await nodes[europe_id].parent.get(db=db)
assert parent_europe is None


async def test_get_many_with_profile(db: InfrahubDatabase, default_branch: Branch, criticality_low, criticality_medium):
profile_schema = registry.schema.get("ProfileTestCriticality", branch=default_branch)
crit_profile_1 = await Node.init(db=db, schema=profile_schema)
Expand Down
23 changes: 21 additions & 2 deletions backend/tests/unit/core/test_node_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,27 @@ async def test_query_NodeListGetRelationshipsQuery(db: InfrahubDatabase, default
await query.execute(db=db)
result = query.get_peers_group_by_node()
assert person_jack_tags_main.id in result
assert "builtintag__testperson" in result[person_jack_tags_main.id]
assert len(result[person_jack_tags_main.id]["builtintag__testperson"]) == 2
assert "inbound::builtintag__testperson" in result[person_jack_tags_main.id]
assert len(result[person_jack_tags_main.id]["inbound::builtintag__testperson"]) == 2


async def test_query_NodeListGetRelationshipsQuery_hierarchical(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node]
):
node_ids = [value.id for value in hierarchical_location_data.values()]
paris_id = hierarchical_location_data["paris"].id
default_branch = await registry.get_branch(db=db, branch="main")
query = await NodeListGetRelationshipsQuery.init(
db=db,
ids=node_ids,
branch=default_branch,
)
await query.execute(db=db)
result = query.get_peers_group_by_node()
assert paris_id in result
assert "inbound::parent__child" in result[paris_id]
assert "outbound::parent__child" in result[paris_id]
assert len(result[paris_id]["inbound::parent__child"]) == 2


async def test_query_NodeDeleteQuery(
Expand Down
Loading