From 3e8a7265f57ee6b5cb23a60a4f43084ac62205a4 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 11 Dec 2024 18:16:34 -0800 Subject: [PATCH 1/3] WIP prototype using DataLoader for single_relationship_resolver --- backend/infrahub/core/query/relationship.py | 2 +- .../dependencies/component/registry.py | 11 +- backend/infrahub/graphql/initialization.py | 6 + backend/infrahub/graphql/resolver.py | 160 ++++++++++++++---- poetry.lock | 22 ++- pyproject.toml | 1 + 6 files changed, 168 insertions(+), 34 deletions(-) diff --git a/backend/infrahub/core/query/relationship.py b/backend/infrahub/core/query/relationship.py index e540769d0d..0dde763f62 100644 --- a/backend/infrahub/core/query/relationship.py +++ b/backend/infrahub/core/query/relationship.py @@ -720,7 +720,7 @@ def get_peer_ids(self) -> list[str]: return [peer.peer_id for peer in self.get_peers()] def get_peers(self) -> Generator[RelationshipPeerData, None, None]: - for result in self.get_results_group_by(("peer", "uuid")): + for result in self.get_results_group_by(("peer", "uuid"), ("source_node", "uuid")): rels = result.get("rels") data = RelationshipPeerData( source_id=result.get_node("source_node").get("uuid"), diff --git a/backend/infrahub/dependencies/component/registry.py b/backend/infrahub/dependencies/component/registry.py index e6c0a91d40..d8f890e983 100644 --- a/backend/infrahub/dependencies/component/registry.py +++ b/backend/infrahub/dependencies/component/registry.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar from ..interface import DependencyBuilderContext from .exceptions import UntrackedDependencyError @@ -21,6 +21,7 @@ class ComponentDependencyRegistry: def __init__(self) -> None: self._available_components: dict[type, type[DependencyBuilder]] = {} + self._cached_components: dict[type, Any] = {} @classmethod def get_registry(cls) -> ComponentDependencyRegistry: @@ -34,6 +35,14 @@ async def get_component(self, component_class: type[T], db: InfrahubDatabase, br context = DependencyBuilderContext(db=db, branch=branch) return self._available_components[component_class].build(context=context) + def cache_component(self, component: Any) -> None: + self._cached_components[type(component)] = component + + def get_cached_component(self, component_class: type[T]) -> T: + if component_class not in self._cached_components: + raise UntrackedDependencyError(f"'{component_class}' is not a cached component") + return self._cached_components[component_class] + def track_dependency(self, dependency_class: type[DependencyBuilder]) -> None: signature = inspect.signature(dependency_class.build) returned_class = signature.return_annotation diff --git a/backend/infrahub/graphql/initialization.py b/backend/infrahub/graphql/initialization.py index c05aff2436..872a058d0c 100644 --- a/backend/infrahub/graphql/initialization.py +++ b/backend/infrahub/graphql/initialization.py @@ -7,7 +7,9 @@ from infrahub.core import registry from infrahub.core.timestamp import Timestamp +from infrahub.dependencies.registry import get_component_registry from infrahub.exceptions import InitializationError +from infrahub.graphql.resolver import SingleRelationshipResolver from .manager import GraphQLSchemaManager @@ -81,6 +83,10 @@ def prepare_graphql_params( if request and not service: service = request.app.state.service + component_registry = get_component_registry() + srr = SingleRelationshipResolver(db=db) + 
component_registry.cache_component(srr) + return GraphqlParams( schema=gql_schema, context=GraphqlContext( diff --git a/backend/infrahub/graphql/resolver.py b/backend/infrahub/graphql/resolver.py index 7fb6a75e22..c0485dffb8 100644 --- a/backend/infrahub/graphql/resolver.py +++ b/backend/infrahub/graphql/resolver.py @@ -1,12 +1,17 @@ from __future__ import annotations +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional +from aiodataloader import DataLoader from infrahub_sdk.utils import extract_fields +from infrahub.core.branch.models import Branch from infrahub.core.constants import BranchSupportType, InfrahubKind, RelationshipHierarchyDirection from infrahub.core.manager import NodeManager from infrahub.core.query.node import NodeGetHierarchyQuery +from infrahub.core.timestamp import Timestamp +from infrahub.dependencies.registry import get_component_registry from infrahub.exceptions import NodeNotFoundError from .parser import extract_selection @@ -14,9 +19,13 @@ from .types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED if TYPE_CHECKING: + from graphql import GraphQLResolveInfo + from infrahub_sdk.schema import RelationshipSchema + from infrahub.core.relationship.model import Relationship from infrahub.core.schema import MainSchemaTypes, NodeSchema + from infrahub.database import InfrahubDatabase from infrahub.graphql.initialization import GraphqlContext @@ -205,42 +214,119 @@ async def default_paginated_list_resolver( return response -async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]: - """Resolver for relationships of cardinality=one for Edged responses +@dataclass +class QueryPeersArgs: + source_kind: str + schema: RelationshipSchema + filters: dict[str, Any] + fields: dict | None = None + at: Timestamp | str | None = None + branch: Branch | str | None = None + branch_agnostic: bool = False + fetch_peers: bool = False + + def __hash__(self) -> int: + frozen_filters: frozenset | None = None + if self.filters: + frozen_filters = to_frozen_set(self.filters) + frozen_fields: frozenset | None = None + if self.fields: + frozen_fields = to_frozen_set(self.fields) + timestamp = Timestamp(self.at) + branch = self.branch.name if isinstance(self.branch, Branch) else self.branch + hash_str = "|".join( + [ + self.source_kind, + self.schema.name, + str(hash(frozen_filters)), + str(hash(frozen_fields)), + timestamp.to_string(), + branch, + str(self.branch_agnostic), + str(self.fetch_peers), + ] + ) + return hash(hash_str) + + +def to_frozen_set(to_freeze: dict[str, Any]) -> frozenset: + freezing_dict = {} + for k, v in to_freeze.items(): + if isinstance(v, dict): + freezing_dict[k] = to_frozen_set(v) + elif isinstance(v, (list, set)): + freezing_dict[k] = frozenset(v) + else: + freezing_dict[k] = v + return frozenset(freezing_dict) + + +class SingleRelationshipResolverDataLoader(DataLoader): + def __init__(self, db: InfrahubDatabase, query_args: QueryPeersArgs, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.query_args = query_args + self.db = db + + async def batch_load_fn(self, keys: list[Any]) -> list[Relationship]: + async with self.db.start_session() as db: + relationships = await NodeManager.query_peers( + db=db, + ids=keys, + source_kind=self.query_args.source_kind, + schema=self.query_args.schema, + filters=self.query_args.filters, + fields=self.query_args.fields, + at=self.query_args.at, + branch=self.query_args.branch, + 
branch_agnostic=self.query_args.branch_agnostic, + fetch_peers=self.query_args.fetch_peers, + ) + relationships_by_node_id = {r.node_id: r for r in relationships} + results = [] + for node_id in keys: + if node_id in relationships_by_node_id: + results.append(relationships_by_node_id[node_id]) + else: + results.append(None) + return results - This resolver is used for paginated responses and as such we redefined the requested - fields by only reusing information below the 'node' key. - """ - # Extract the InfraHub schema by inspecting the GQL Schema - node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema +class SingleRelationshipResolver: + def __init__(self, db: InfrahubDatabase) -> None: + self.db = db + self._data_loader_instances: dict[QueryPeersArgs, SingleRelationshipResolverDataLoader] = {} - context: GraphqlContext = info.context + async def resolve(self, parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]: + """Resolver for relationships of cardinality=one for Edged responses - # Extract the name of the fields in the GQL query - fields = await extract_fields(info.field_nodes[0].selection_set) - node_fields = fields.get("node", {}) - property_fields = fields.get("properties", {}) - for key, value in property_fields.items(): - mapped_name = RELATIONS_PROPERTY_MAP[key] - node_fields[mapped_name] = value + This resolver is used for paginated responses and as such we redefined the requested + fields by only reusing information below the 'node' key. + """ + # Extract the InfraHub schema by inspecting the GQL Schema - # Extract the schema of the node on the other end of the relationship from the GQL Schema - node_rel = node_schema.get_relationship(info.field_name) + node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema - # Extract only the filters from the kwargs and prepend the name of the field to the filters - filters = { - f"{info.field_name}__{key}": value - for key, value in kwargs.items() - if "__" in key and value or key in ["id", "ids"] - } + context: GraphqlContext = info.context - response: dict[str, Any] = {"node": None, "properties": {}} + # Extract the name of the fields in the GQL query + fields = await extract_fields(info.field_nodes[0].selection_set) + node_fields = fields.get("node", {}) + property_fields = fields.get("properties", {}) + for key, value in property_fields.items(): + mapped_name = RELATIONS_PROPERTY_MAP[key] + node_fields[mapped_name] = value - async with context.db.start_session() as db: - objs = await NodeManager.query_peers( - db=db, - ids=[parent["id"]], + # Extract the schema of the node on the other end of the relationship from the GQL Schema + node_rel = node_schema.get_relationship(info.field_name) + + # Extract only the filters from the kwargs and prepend the name of the field to the filters + filters = { + f"{info.field_name}__{key}": value + for key, value in kwargs.items() + if "__" in key and value or key in ["id", "ids"] + } + + query_args = QueryPeersArgs( source_kind=node_schema.kind, schema=node_rel, filters=filters, @@ -251,10 +337,18 @@ async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, * fetch_peers=True, ) - if not objs: + if query_args in self._data_loader_instances: + loader = self._data_loader_instances[query_args] + else: + loader = SingleRelationshipResolverDataLoader(db=context.db, query_args=query_args) + self._data_loader_instances[query_args] = loader + result = await loader.load(key=parent["id"]) + + response: dict[str, Any] = {"node": None, 
"properties": {}} + if not result: return response - node_graph = await objs[0].to_graphql(db=db, fields=node_fields, related_node_ids=context.related_node_ids) + node_graph = await result.to_graphql(db=self.db, fields=node_fields, related_node_ids=context.related_node_ids) for key, mapped in RELATIONS_PROPERTY_MAP_REVERSED.items(): value = node_graph.pop(key, None) if value: @@ -263,6 +357,12 @@ async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, * return response +async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]: + component_registry = get_component_registry() + resolver: SingleRelationshipResolver = component_registry.get_cached_component(SingleRelationshipResolver) + return await resolver.resolve(parent=parent, info=info, **kwargs) + + async def many_relationship_resolver( parent: dict, info: GraphQLResolveInfo, include_descendants: Optional[bool] = False, **kwargs ) -> dict[str, Any]: diff --git a/poetry.lock b/poetry.lock index 9fcd39e69f..43d18d4632 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aio-pika" @@ -15,6 +15,24 @@ files = [ aiormq = ">=6.8.0,<6.9.0" yarl = "*" +[[package]] +name = "aiodataloader" +version = "0.4.0" +description = "Asyncio DataLoader implementation for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiodataloader-0.4.0-py3-none-any.whl", hash = "sha256:2775d8607e1b68ded82efc93c839846d0ae9d9e0421085e444f3c1c541f3c2b6"}, + {file = "aiodataloader-0.4.0.tar.gz", hash = "sha256:6de9ca0eb75c4eef686754679981ebd45a933391b40473f92419b8a747504169"}, +] + +[package.dependencies] +typing-extensions = ">=4.1.1" + +[package.extras] +lint = ["black", "flake8", "flake8-import-order", "mypy"] +test = ["coveralls", "mock", "pytest (>=3.6)", "pytest-asyncio", "pytest-cov"] + [[package]] name = "aiormq" version = "6.8.1" @@ -5800,4 +5818,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, < 3.13" -content-hash = "7b1ff36c08dd52951d7caa461cbf3803c8c1ce9518648d3b2d9aee452b55edf0" +content-hash = "fa342c818089596f9ba51ddad954ce9092dbbb9116937549a338f905d9ec3b42" diff --git a/pyproject.toml b/pyproject.toml index 20c58855f6..ef9145bc52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ opentelemetry-exporter-otlp-proto-http = "1.28.1" nats-py = "^2.7.2" netaddr = "1.3.0" authlib = "1.3.2" +aiodataloader = "0.4.0" # Dependencies specific to the SDK From 84e2005c5e204f1e8251ff638c4e7215fb112412 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 11 Dec 2024 19:44:29 -0800 Subject: [PATCH 2/3] format --- backend/infrahub/graphql/resolver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/infrahub/graphql/resolver.py b/backend/infrahub/graphql/resolver.py index c0485dffb8..13a454ad4e 100644 --- a/backend/infrahub/graphql/resolver.py +++ b/backend/infrahub/graphql/resolver.py @@ -19,7 +19,6 @@ from .types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED if TYPE_CHECKING: - from graphql import GraphQLResolveInfo from infrahub_sdk.schema import RelationshipSchema From 26a928a68be5da7026b2c3c21b238934aa6b9c31 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 12 Dec 2024 07:18:23 -0800 Subject: [PATCH 3/3] pylint --- backend/infrahub/graphql/resolver.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/backend/infrahub/graphql/resolver.py b/backend/infrahub/graphql/resolver.py index 13a454ad4e..86f196b2ad 100644 --- a/backend/infrahub/graphql/resolver.py +++ b/backend/infrahub/graphql/resolver.py @@ -266,7 +266,7 @@ def __init__(self, db: InfrahubDatabase, query_args: QueryPeersArgs, *args: Any, self.query_args = query_args self.db = db - async def batch_load_fn(self, keys: list[Any]) -> list[Relationship]: + async def batch_load_fn(self, keys: list[Any]) -> list[Relationship]: # pylint: disable=method-hidden async with self.db.start_session() as db: relationships = await NodeManager.query_peers( db=db,
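
For reference, the batching behaviour this series leans on comes from aiodataloader: every .load(key) call issued within the same event-loop tick is coalesced into a single batch_load_fn(keys) call, which is why the resolver can make one NodeManager.query_peers round trip per unique QueryPeersArgs instead of one per parent node. The following is a minimal, standalone sketch of that behaviour under stated assumptions; PeerLoader and the in-memory lookup are illustrative stand-ins, not code from the patch.

import asyncio

from aiodataloader import DataLoader


class PeerLoader(DataLoader):
    async def batch_load_fn(self, keys: list[str]) -> list[str | None]:
        # Called once per batch: every .load() issued in the same tick lands here together.
        print(f"batch_load_fn called with {len(keys)} keys: {keys}")
        # Stand-in for NodeManager.query_peers; a real loader would hit the database once here.
        fetched = {key: f"peer-of-{key}" for key in keys}
        # DataLoader expects one result per key, in the same order as `keys`.
        return [fetched.get(key) for key in keys]


async def main() -> None:
    loader = PeerLoader()
    # Three concurrent resolver-style lookups are served by a single batch_load_fn call,
    # and the repeated "n1" is deduplicated by the loader's per-key cache.
    peers = await asyncio.gather(loader.load("n1"), loader.load("n2"), loader.load("n1"))
    print(peers)  # ['peer-of-n1', 'peer-of-n2', 'peer-of-n1']


asyncio.run(main())

Run against aiodataloader 0.4.0, batch_load_fn should fire once with ["n1", "n2"], which mirrors how SingleRelationshipResolver reuses one loader per QueryPeersArgs so sibling cardinality=one fields resolved in the same GraphQL pass share a single query.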