Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

WIP prototype using DataLoader for single_relationship_resolver #5196

Draft
wants to merge 3 commits into
base: release-1.1
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/infrahub/core/query/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def get_peer_ids(self) -> list[str]:
return [peer.peer_id for peer in self.get_peers()]

def get_peers(self) -> Generator[RelationshipPeerData, None, None]:
for result in self.get_results_group_by(("peer", "uuid")):
for result in self.get_results_group_by(("peer", "uuid"), ("source_node", "uuid")):
rels = result.get("rels")
data = RelationshipPeerData(
source_id=result.get_node("source_node").get("uuid"),
Expand Down
11 changes: 10 additions & 1 deletion backend/infrahub/dependencies/component/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from ..interface import DependencyBuilderContext
from .exceptions import UntrackedDependencyError
Expand All @@ -21,6 +21,7 @@ class ComponentDependencyRegistry:

def __init__(self) -> None:
    """Create an empty registry with no tracked builders and no cached instances."""
    # Maps a component class to the builder class used to construct it.
    self._available_components: dict[type, type[DependencyBuilder]] = {}
    # Maps a component class to an already-built instance stored via cache_component().
    self._cached_components: dict[type, Any] = {}

@classmethod
def get_registry(cls) -> ComponentDependencyRegistry:
Expand All @@ -34,6 +35,14 @@ async def get_component(self, component_class: type[T], db: InfrahubDatabase, br
context = DependencyBuilderContext(db=db, branch=branch)
return self._available_components[component_class].build(context=context)

def cache_component(self, component: Any) -> None:
    """Store an already-built component instance, keyed by its concrete type.

    A later call with another instance of the same type replaces the earlier one.
    """
    component_class = type(component)
    self._cached_components[component_class] = component

def get_cached_component(self, component_class: type[T]) -> T:
    """Return the cached instance stored under ``component_class``.

    Raises:
        UntrackedDependencyError: if nothing has been cached for that class.
    """
    cached = self._cached_components
    if component_class in cached:
        return cached[component_class]
    raise UntrackedDependencyError(f"'{component_class}' is not a cached component")

def track_dependency(self, dependency_class: type[DependencyBuilder]) -> None:
signature = inspect.signature(dependency_class.build)
returned_class = signature.return_annotation
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/graphql/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from infrahub.core import registry
from infrahub.core.timestamp import Timestamp
from infrahub.dependencies.registry import get_component_registry
from infrahub.exceptions import InitializationError
from infrahub.graphql.resolver import SingleRelationshipResolver

from .manager import GraphQLSchemaManager

Expand Down Expand Up @@ -81,6 +83,10 @@ def prepare_graphql_params(
if request and not service:
service = request.app.state.service

component_registry = get_component_registry()
srr = SingleRelationshipResolver(db=db)
component_registry.cache_component(srr)

return GraphqlParams(
schema=gql_schema,
context=GraphqlContext(
Expand Down
159 changes: 129 additions & 30 deletions backend/infrahub/graphql/resolver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

from aiodataloader import DataLoader
from infrahub_sdk.utils import extract_fields

from infrahub.core.branch.models import Branch
from infrahub.core.constants import BranchSupportType, InfrahubKind, RelationshipHierarchyDirection
from infrahub.core.manager import NodeManager
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.core.timestamp import Timestamp
from infrahub.dependencies.registry import get_component_registry
from infrahub.exceptions import NodeNotFoundError

from .parser import extract_selection
Expand All @@ -15,8 +20,11 @@

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
from infrahub_sdk.schema import RelationshipSchema

from infrahub.core.relationship.model import Relationship
from infrahub.core.schema import MainSchemaTypes, NodeSchema
from infrahub.database import InfrahubDatabase
from infrahub.graphql.initialization import GraphqlContext


Expand Down Expand Up @@ -205,42 +213,119 @@ async def default_paginated_list_resolver(
return response


async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
"""Resolver for relationships of cardinality=one for Edged responses
@dataclass
class QueryPeersArgs:
    """Arguments that identify one family of peer lookups.

    Instances are used as dictionary keys to find the data loader that serves a
    given combination of arguments, so ``__hash__`` is defined explicitly; the
    dataclass-generated ``__eq__`` still compares every field. Do not mutate an
    instance after it has been used as a key.
    """

    source_kind: str
    schema: RelationshipSchema
    filters: dict[str, Any]
    fields: dict | None = None
    at: Timestamp | str | None = None
    branch: Branch | str | None = None
    branch_agnostic: bool = False
    fetch_peers: bool = False

    def __hash__(self) -> int:
        """Hash the arguments, freezing the nested ``filters`` and ``fields`` dicts.

        The parts are hashed as a tuple rather than joined into a string, so a
        ``branch`` of ``None`` (the field default) no longer raises ``TypeError``.
        """
        frozen_filters = to_frozen_set(self.filters) if self.filters else None
        frozen_fields = to_frozen_set(self.fields) if self.fields else None
        # NOTE(review): if Timestamp(None) resolves to "now", the hash is not stable
        # when ``at`` is unset — confirm against the Timestamp implementation.
        timestamp = Timestamp(self.at)
        # A Branch object and its name must hash alike; a plain str or None is used as-is.
        branch_name = self.branch.name if isinstance(self.branch, Branch) else self.branch
        return hash(
            (
                self.source_kind,
                self.schema.name,
                frozen_filters,
                frozen_fields,
                timestamp.to_string(),
                branch_name,
                self.branch_agnostic,
                self.fetch_peers,
            )
        )


def to_frozen_set(to_freeze: dict[str, Any]) -> frozenset:
    """Return a hashable, order-independent snapshot of a (possibly nested) dict.

    The result is a frozenset of ``(key, frozen_value)`` pairs, so two dicts
    freeze to equal values only when both their keys and their values match.
    Nested dicts are frozen recursively; lists and sets become frozensets of
    their (recursively frozen) elements; any other value is kept as-is and
    must itself be hashable.

    Args:
        to_freeze: the dictionary to convert.

    Returns:
        A frozenset of ``(key, frozen_value)`` tuples.
    """

    def _freeze(value: Any) -> Any:
        # Convert the unhashable containers into hashable equivalents.
        if isinstance(value, dict):
            return to_frozen_set(value)
        if isinstance(value, (list, set)):
            return frozenset(_freeze(item) for item in value)
        return value

    # Freeze the items, not the dict itself: iterating a dict yields only its
    # keys, which would silently drop every value from the result.
    return frozenset((key, _freeze(value)) for key, value in to_freeze.items())


class SingleRelationshipResolverDataLoader(DataLoader):
    """DataLoader that resolves peers for many source node ids in one ``query_peers`` call.

    Every key loaded through one instance shares the same ``query_args``.
    """

    def __init__(self, db: InfrahubDatabase, query_args: QueryPeersArgs, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.db = db
        self.query_args = query_args

    async def batch_load_fn(self, keys: list[Any]) -> list[Relationship]:  # pylint: disable=method-hidden
        """Fetch the relationships for ``keys`` in a single query.

        Returns one entry per key, in the same order as ``keys``.
        NOTE(review): an entry is None when no relationship came back for that key,
        which the ``list[Relationship]`` annotation does not express.
        """
        args = self.query_args
        async with self.db.start_session() as db:
            relationships = await NodeManager.query_peers(
                db=db,
                ids=keys,
                source_kind=args.source_kind,
                schema=args.schema,
                filters=args.filters,
                fields=args.fields,
                at=args.at,
                branch=args.branch,
                branch_agnostic=args.branch_agnostic,
                fetch_peers=args.fetch_peers,
            )
        # Index by source node id; a later relationship for the same id replaces an earlier one.
        by_node_id = {}
        for relationship in relationships:
            by_node_id[relationship.node_id] = relationship
        return [by_node_id.get(node_id) for node_id in keys]

This resolver is used for paginated responses and as such we redefined the requested
fields by only reusing information below the 'node' key.
"""
# Extract the InfraHub schema by inspecting the GQL Schema

node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema
class SingleRelationshipResolver:
def __init__(self, db: InfrahubDatabase) -> None:
self.db = db
self._data_loader_instances: dict[QueryPeersArgs, SingleRelationshipResolverDataLoader] = {}

context: GraphqlContext = info.context
async def resolve(self, parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
"""Resolver for relationships of cardinality=one for Edged responses

# Extract the name of the fields in the GQL query
fields = await extract_fields(info.field_nodes[0].selection_set)
node_fields = fields.get("node", {})
property_fields = fields.get("properties", {})
for key, value in property_fields.items():
mapped_name = RELATIONS_PROPERTY_MAP[key]
node_fields[mapped_name] = value
This resolver is used for paginated responses and as such we redefined the requested
fields by only reusing information below the 'node' key.
"""
# Extract the InfraHub schema by inspecting the GQL Schema

# Extract the schema of the node on the other end of the relationship from the GQL Schema
node_rel = node_schema.get_relationship(info.field_name)
node_schema: NodeSchema = info.parent_type.graphene_type._meta.schema

# Extract only the filters from the kwargs and prepend the name of the field to the filters
filters = {
f"{info.field_name}__{key}": value
for key, value in kwargs.items()
if "__" in key and value or key in ["id", "ids"]
}
context: GraphqlContext = info.context

response: dict[str, Any] = {"node": None, "properties": {}}
# Extract the name of the fields in the GQL query
fields = await extract_fields(info.field_nodes[0].selection_set)
node_fields = fields.get("node", {})
property_fields = fields.get("properties", {})
for key, value in property_fields.items():
mapped_name = RELATIONS_PROPERTY_MAP[key]
node_fields[mapped_name] = value

async with context.db.start_session() as db:
objs = await NodeManager.query_peers(
db=db,
ids=[parent["id"]],
# Extract the schema of the node on the other end of the relationship from the GQL Schema
node_rel = node_schema.get_relationship(info.field_name)

# Extract only the filters from the kwargs and prepend the name of the field to the filters
filters = {
f"{info.field_name}__{key}": value
for key, value in kwargs.items()
if "__" in key and value or key in ["id", "ids"]
}

query_args = QueryPeersArgs(
source_kind=node_schema.kind,
schema=node_rel,
filters=filters,
Expand All @@ -251,10 +336,18 @@ async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, *
fetch_peers=True,
)

if not objs:
if query_args in self._data_loader_instances:
loader = self._data_loader_instances[query_args]
else:
loader = SingleRelationshipResolverDataLoader(db=context.db, query_args=query_args)
self._data_loader_instances[query_args] = loader
result = await loader.load(key=parent["id"])

response: dict[str, Any] = {"node": None, "properties": {}}
if not result:
return response

node_graph = await objs[0].to_graphql(db=db, fields=node_fields, related_node_ids=context.related_node_ids)
node_graph = await result.to_graphql(db=self.db, fields=node_fields, related_node_ids=context.related_node_ids)
for key, mapped in RELATIONS_PROPERTY_MAP_REVERSED.items():
value = node_graph.pop(key, None)
if value:
Expand All @@ -263,6 +356,12 @@ async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, *
return response


async def single_relationship_resolver(parent: dict, info: GraphQLResolveInfo, **kwargs) -> dict[str, Any]:
    """Resolve a cardinality-one relationship through the cached SingleRelationshipResolver.

    The resolver instance is looked up in the component registry, and the call
    is forwarded to its ``resolve`` method with the arguments unchanged.
    """
    resolver: SingleRelationshipResolver = get_component_registry().get_cached_component(SingleRelationshipResolver)
    return await resolver.resolve(parent=parent, info=info, **kwargs)


async def many_relationship_resolver(
parent: dict, info: GraphQLResolveInfo, include_descendants: Optional[bool] = False, **kwargs
) -> dict[str, Any]:
Expand Down
22 changes: 20 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ opentelemetry-exporter-otlp-proto-http = "1.28.1"
nats-py = "^2.7.2"
netaddr = "1.3.0"
authlib = "1.3.2"
aiodataloader = "0.4.0"


# Dependencies specific to the SDK
Expand Down
Loading