Skip to content

Commit

Permalink
Merge pull request #840 from neo4j-contrib/improved_intermediate_tran…
Browse files Browse the repository at this point in the history
…sform

Improved intermediate_transform method.
  • Loading branch information
mariusconjeaud authored Nov 22, 2024
2 parents 68ee553 + a13697f commit f866109
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 81 deletions.
3 changes: 3 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Version 5.4.1 2024-11
* Add options for intermediate_transform : distinct, include_in_return, use a prop as source

Version 5.4.0 2024-11
* Traversal option for filtering and ordering
* Insert raw Cypher for ordering
Expand Down
37 changes: 34 additions & 3 deletions doc/source/advanced_query_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,41 @@ As discussed in the note above, this is for example useful when you need to orde
# This will return all Coffee nodes, with their most expensive supplier
Coffee.nodes.traverse_relations(suppliers="suppliers")
.intermediate_transform(
{"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
{"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"]
)
.annotate(supps=Last(Collect("suppliers")))

Options for `intermediate_transform` *variables* are:

- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation.
- `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False.

Additional options for the `intermediate_transform` method are:
- `distinct`: `bool` - whether to deduplicate the results. Defaults to False.

Here is a full example::

await Coffee.nodes.fetch_relations("suppliers")
.intermediate_transform(
{
"coffee": "coffee",
"suppliers": NodeNameResolver("suppliers"),
"r": RelationNameResolver("suppliers"),
"coffee": {"source": "coffee", "include_in_return": True}, # Only coffee will be returned
"suppliers": {"source": NodeNameResolver("suppliers")},
"r": {"source": RelationNameResolver("suppliers")},
"cost": {
"source": NodeNameResolver("suppliers"),
"source_prop": "delivery_cost",
},
},
distinct=True,
ordering=["-r.since"],
)
.annotate(oldest_supplier=Last(Collect("suppliers")))
.all()

Subqueries
----------

Expand All @@ -71,7 +102,7 @@ The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.co
.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers")
.intermediate_transform(
{"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
{"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"]
)
.annotate(supps=Last(Collect("suppliers"))),
["supps"],
Expand Down Expand Up @@ -108,4 +139,4 @@ In some cases though, it is not possible to set explicit aliases, for example wh

.. note::

When using the resolvers in combination with a traversal as in the example above, it will resolve the variable name of the last element in the traversal - the Species node for NodeNameResolver, and Coffee--Species relationship for RelationshipNameResolver.
When using the resolvers in combination with a traversal as in the example above, it will resolve the variable name of the last element in the traversal - the Species node for NodeNameResolver, and Coffee--Species relationship for RelationshipNameResolver.
74 changes: 42 additions & 32 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
from neomodel.properties import AliasProperty, ArrayProperty, Property
from neomodel.typing import Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -838,32 +839,27 @@ def build_query(self) -> str:
query += " WITH "
query += self._ast.with_clause

returned_items: list[str] = []
if hasattr(self.node_set, "_intermediate_transforms"):
for transform in self.node_set._intermediate_transforms:
query += " WITH "
query += "DISTINCT " if transform.get("distinct") else ""
injected_vars: list = []
# Reset return list since we'll probably invalidate most variables
self._ast.return_clause = ""
self._ast.additional_return = []
for name, source in transform["vars"].items():
if type(source) is str:
injected_vars.append(f"{source} AS {name}")
elif isinstance(source, RelationNameResolver):
result = self.lookup_query_variable(
source.relation, return_relation=True
)
if not result:
raise ValueError(
f"Unable to resolve variable name for relation {source.relation}."
)
injected_vars.append(f"{result[0]} AS {name}")
elif isinstance(source, NodeNameResolver):
result = self.lookup_query_variable(source.node)
if not result:
raise ValueError(
f"Unable to resolve variable name for node {source.node}."
)
injected_vars.append(f"{result[0]} AS {name}")
for name, varprops in transform["vars"].items():
source = varprops["source"]
if isinstance(source, (NodeNameResolver, RelationNameResolver)):
transformation = source.resolve(self)
else:
transformation = source
if varprops.get("source_prop"):
transformation += f".{varprops['source_prop']}"
transformation += f" AS {name}"
if varprops.get("include_in_return"):
returned_items += [name]
injected_vars.append(transformation)
query += ",".join(injected_vars)
if not transform["ordering"]:
continue
Expand All @@ -879,7 +875,6 @@ def build_query(self) -> str:
ordering.append(item)
query += ",".join(ordering)

returned_items: list[str] = []
if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
Expand Down Expand Up @@ -1098,6 +1093,14 @@ class RelationNameResolver:

relation: str

def resolve(self, qbuilder: AsyncQueryBuilder) -> str:
result = qbuilder.lookup_query_variable(self.relation, True)
if result is None:
raise ValueError(

Check warning on line 1099 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1099

Added line #L1099 was not covered by tests
f"Unable to resolve variable name for relation {self.relation}"
)
return result[0]


@dataclass
class NodeNameResolver:
Expand All @@ -1111,6 +1114,12 @@ class NodeNameResolver:

node: str

def resolve(self, qbuilder: AsyncQueryBuilder) -> str:
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")

Check warning on line 1120 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1120

Added line #L1120 was not covered by tests
return result[0]


@dataclass
class BaseFunction:
Expand All @@ -1123,15 +1132,10 @@ def get_internal_name(self) -> str:
return self._internal_name

def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str:
if isinstance(self.input_name, NodeNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.node)
elif isinstance(self.input_name, RelationNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.relation, True)
if isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)):
self._internal_name = self.input_name.resolve(qbuilder)
else:
result = (str(self.input_name), None)
if result is None:
raise ValueError(f"Unknown variable {self.input_name} used in Collect()")
self._internal_name = result[0]
self._internal_name = str(self.input_name)
return self._internal_name

def render(self, qbuilder: AsyncQueryBuilder) -> str:
Expand Down Expand Up @@ -1538,20 +1542,26 @@ async def subquery(
return self

def intermediate_transform(
self, vars: Dict[str, Any], ordering: TOptional[list] = None
self,
vars: Dict[str, Transformation],
distinct: bool = False,
ordering: TOptional[list] = None,
) -> "AsyncNodeSet":
if not vars:
raise ValueError(
"You must provide one variable at least when calling intermediate_transform()"
)
for name, source in vars.items():
for name, props in vars.items():
source = props["source"]
if type(source) is not str and not isinstance(
source, (NodeNameResolver, RelationNameResolver)
source, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(
f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._intermediate_transforms.append({"vars": vars, "ordering": ordering})
self._intermediate_transforms.append(
{"vars": vars, "distinct": distinct, "ordering": ordering}
)
return self


Expand Down
74 changes: 42 additions & 32 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from neomodel.sync_ import relationship_manager
from neomodel.sync_.core import StructuredNode, db
from neomodel.sync_.relationship import StructuredRel
from neomodel.typing import Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -840,32 +841,27 @@ def build_query(self) -> str:
query += " WITH "
query += self._ast.with_clause

returned_items: list[str] = []
if hasattr(self.node_set, "_intermediate_transforms"):
for transform in self.node_set._intermediate_transforms:
query += " WITH "
query += "DISTINCT " if transform.get("distinct") else ""
injected_vars: list = []
# Reset return list since we'll probably invalidate most variables
self._ast.return_clause = ""
self._ast.additional_return = []
for name, source in transform["vars"].items():
if type(source) is str:
injected_vars.append(f"{source} AS {name}")
elif isinstance(source, RelationNameResolver):
result = self.lookup_query_variable(
source.relation, return_relation=True
)
if not result:
raise ValueError(
f"Unable to resolve variable name for relation {source.relation}."
)
injected_vars.append(f"{result[0]} AS {name}")
elif isinstance(source, NodeNameResolver):
result = self.lookup_query_variable(source.node)
if not result:
raise ValueError(
f"Unable to resolve variable name for node {source.node}."
)
injected_vars.append(f"{result[0]} AS {name}")
for name, varprops in transform["vars"].items():
source = varprops["source"]
if isinstance(source, (NodeNameResolver, RelationNameResolver)):
transformation = source.resolve(self)
else:
transformation = source
if varprops.get("source_prop"):
transformation += f".{varprops['source_prop']}"
transformation += f" AS {name}"
if varprops.get("include_in_return"):
returned_items += [name]
injected_vars.append(transformation)
query += ",".join(injected_vars)
if not transform["ordering"]:
continue
Expand All @@ -881,7 +877,6 @@ def build_query(self) -> str:
ordering.append(item)
query += ",".join(ordering)

returned_items: list[str] = []
if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
Expand Down Expand Up @@ -1098,6 +1093,14 @@ class RelationNameResolver:

relation: str

def resolve(self, qbuilder: QueryBuilder) -> str:
result = qbuilder.lookup_query_variable(self.relation, True)
if result is None:
raise ValueError(

Check warning on line 1099 in neomodel/sync_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/sync_/match.py#L1099

Added line #L1099 was not covered by tests
f"Unable to resolve variable name for relation {self.relation}"
)
return result[0]


@dataclass
class NodeNameResolver:
Expand All @@ -1111,6 +1114,12 @@ class NodeNameResolver:

node: str

def resolve(self, qbuilder: QueryBuilder) -> str:
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")

Check warning on line 1120 in neomodel/sync_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/sync_/match.py#L1120

Added line #L1120 was not covered by tests
return result[0]


@dataclass
class BaseFunction:
Expand All @@ -1123,15 +1132,10 @@ def get_internal_name(self) -> str:
return self._internal_name

def resolve_internal_name(self, qbuilder: QueryBuilder) -> str:
if isinstance(self.input_name, NodeNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.node)
elif isinstance(self.input_name, RelationNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.relation, True)
if isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)):
self._internal_name = self.input_name.resolve(qbuilder)
else:
result = (str(self.input_name), None)
if result is None:
raise ValueError(f"Unknown variable {self.input_name} used in Collect()")
self._internal_name = result[0]
self._internal_name = str(self.input_name)
return self._internal_name

def render(self, qbuilder: QueryBuilder) -> str:
Expand Down Expand Up @@ -1536,20 +1540,26 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet":
return self

def intermediate_transform(
self, vars: Dict[str, Any], ordering: TOptional[list] = None
self,
vars: Dict[str, Transformation],
distinct: bool = False,
ordering: TOptional[list] = None,
) -> "NodeSet":
if not vars:
raise ValueError(
"You must provide one variable at least when calling intermediate_transform()"
)
for name, source in vars.items():
for name, props in vars.items():
source = props["source"]
if type(source) is not str and not isinstance(
source, (NodeNameResolver, RelationNameResolver)
source, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(
f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._intermediate_transforms.append({"vars": vars, "ordering": ordering})
self._intermediate_transforms.append(
{"vars": vars, "distinct": distinct, "ordering": ordering}
)
return self


Expand Down
12 changes: 12 additions & 0 deletions neomodel/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Custom types used for annotations."""

from typing import Any, Optional, TypedDict

Transformation = TypedDict(
"Transformation",
{
"source": Any,
"source_prop": Optional[str],
"include_in_return": Optional[bool],
},
)
Loading

0 comments on commit f866109

Please sign in to comment.