diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordResultBuilderForStreamMode.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordResultBuilderForStreamMode.java index 9c41139395..dabb168219 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordResultBuilderForStreamMode.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordResultBuilderForStreamMode.java @@ -22,18 +22,28 @@ import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; +import org.neo4j.gds.api.IdMap; import org.neo4j.gds.api.NodeLookup; import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder; +import org.neo4j.gds.paths.PathResult; import org.neo4j.gds.paths.bellmanford.BellmanFordResult; import org.neo4j.gds.paths.dijkstra.PathFindingResult; +import org.neo4j.graphdb.RelationshipType; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Optional; +import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + class BellmanFordResultBuilderForStreamMode implements StreamResultBuilder { private final CloseableResourceRegistry closeableResourceRegistry; private final NodeLookup nodeLookup; private final boolean routeRequested; + private static final String COST_PROPERTY_NAME = "cost"; + private static final String RELATIONSHIP_TYPE_TEMPLATE = "PATH_%d"; BellmanFordResultBuilderForStreamMode( CloseableResourceRegistry closeableResourceRegistry, @@ -56,23 +66,15 @@ public Stream build( var bellmanFordResult = result.get(); // this is us handling the case of generated graphs and such - var shouldCreateRoutes = routeRequested && graphStore.capabilities().canWriteToLocalDatabase(); - - var containsNegativeCycle = bellmanFordResult.containsNegativeCycle(); - - var resultBuilder = new BellmanFordStreamResult.Builder(graph, nodeLookup) - .withIsCycle(containsNegativeCycle); + var pathFactoryFacade = PathFactoryFacade.create(routeRequested, nodeLookup,graphStore); - var algorithmResult = getPathFindingResult(bellmanFordResult, containsNegativeCycle); + var algorithmResult = getPathFindingResult(bellmanFordResult, bellmanFordResult.containsNegativeCycle()); - var resultStream = algorithmResult.mapPaths(path -> resultBuilder.build( - path.nodeIds(), - path.costs(), - path.index(), - path.sourceNode(), - path.targetNode(), - path.totalCost(), - shouldCreateRoutes + var resultStream = algorithmResult.mapPaths(route -> mapRoute( + route, + graph, + bellmanFordResult.containsNegativeCycle(), + pathFactoryFacade )); closeableResourceRegistry.register(resultStream); @@ -80,6 +82,36 @@ public Stream build( return resultStream; } + private BellmanFordStreamResult mapRoute(PathResult pathResult, IdMap idMap, boolean negativeCycle,PathFactoryFacade pathFactoryFacade){ + var nodeIds = pathResult.nodeIds(); + for (int i = 0; i < nodeIds.length; i++) { + nodeIds[i] = idMap.toOriginalNodeId(nodeIds[i]); + } + var relationshipType = RelationshipType.withName(formatWithLocale(RELATIONSHIP_TYPE_TEMPLATE, pathResult.index())); + + double[] costs = pathResult.costs(); + + var path = pathFactoryFacade.createPath( + nodeIds, + costs, + relationshipType, + COST_PROPERTY_NAME + ); + + return new BellmanFordStreamResult( + pathResult.index(), + idMap.toOriginalNodeId(pathResult.sourceNode()), + idMap.toOriginalNodeId(pathResult.targetNode()), + pathResult.totalCost(), + // 😿 + Arrays.stream(nodeIds).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))), + Arrays.stream(costs).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))), + path, + negativeCycle + ); + } + + private static PathFindingResult getPathFindingResult( BellmanFordResult result, boolean containsNegativeCycle @@ -89,3 +121,4 @@ private static PathFindingResult getPathFindingResult( return result.shortestPaths(); } } + diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BfsStreamResultBuilder.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BfsStreamResultBuilder.java index 6b3bbc2a38..1b5dec217f 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BfsStreamResultBuilder.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BfsStreamResultBuilder.java @@ -57,8 +57,7 @@ public Stream build( result.get(), graph::toOriginalNodeId, BfsStreamResult::new, - pathRequested && graphStore.capabilities().canWriteToLocalDatabase(), - new PathFactoryFacade(), + PathFactoryFacade.create(pathRequested, nodeLookup,graphStore), RelationshipType.withName(RELATIONSHIP_TYPE_NAME), nodeLookup ); diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/DfsStreamResultBuilder.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/DfsStreamResultBuilder.java index 9e1caaac98..c47754188e 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/DfsStreamResultBuilder.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/DfsStreamResultBuilder.java @@ -57,8 +57,7 @@ public Stream build( result.get(), graph::toOriginalNodeId, DfsStreamResult::new, - pathRequested && graphStore.capabilities().canWriteToLocalDatabase(), - new PathFactoryFacade(), + PathFactoryFacade.create(pathRequested, nodeLookup,graphStore), RelationshipType.withName(RELATIONSHIP_TYPE_NAME), nodeLookup ); diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFactoryFacade.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFactoryFacade.java index a6b641e411..f851ce588b 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFactoryFacade.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFactoryFacade.java @@ -19,6 +19,7 @@ */ package org.neo4j.gds.procedures.algorithms.pathfinding; +import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.api.NodeLookup; import org.neo4j.gds.paths.PathFactory; import org.neo4j.graphdb.Path; @@ -27,15 +28,44 @@ import java.util.List; public class PathFactoryFacade { + private final boolean canCreatePaths; + private final NodeLookup nodeLookup; + + private PathFactoryFacade(boolean canCreatePaths, NodeLookup nodeLookup) { + this.canCreatePaths = canCreatePaths; + this.nodeLookup = nodeLookup; + } + + public static PathFactoryFacade create(boolean pathIsYielded, NodeLookup nodeLookup, GraphStore graphStore){ + var canCreatePaths = pathIsYielded && graphStore.capabilities().canWriteToLocalDatabase(); + return new PathFactoryFacade(canCreatePaths,nodeLookup); + } + public Path createPath( - NodeLookup nodeLookup, List nodeList, RelationshipType relationshipType ) { + if (!canCreatePaths) return null; return PathFactory.create( nodeLookup, nodeList, relationshipType ); } + + public Path createPath( + long[] nodeList, + double[] costs, + RelationshipType relationshipType, + String costPropertyName + ) { + if (!canCreatePaths) return null; + return PathFactory.create( + nodeLookup, + nodeList, + costs, + relationshipType, + costPropertyName + ); + } } diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingResultBuilderForStreamMode.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingResultBuilderForStreamMode.java index 96565042af..d36ab953b3 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingResultBuilderForStreamMode.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingResultBuilderForStreamMode.java @@ -24,9 +24,8 @@ import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.api.NodeLookup; import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder; -import org.neo4j.gds.paths.PathFactory; +import org.neo4j.gds.paths.PathResult; import org.neo4j.gds.paths.dijkstra.PathFindingResult; -import org.neo4j.graphdb.Path; import org.neo4j.graphdb.RelationshipType; import java.util.ArrayList; @@ -60,52 +59,47 @@ public Stream build( ) { if (result.isEmpty()) return Stream.of(); - // this is us handling the case of generated graphs and such - var createCypherPaths = pathRequested && graphStore.capabilities().canWriteToLocalDatabase(); - var pathFindingResult = result.get(); + var pathFactoryFacade= PathFactoryFacade.create(pathRequested, nodeLookup,graphStore); - var resultStream = pathFindingResult.mapPaths(pathResult -> { - var nodeIds = pathResult.nodeIds(); - var costs = pathResult.costs(); - var pathIndex = pathResult.index(); + var resultStream = pathFindingResult.mapPaths(pathResult -> mapPath(pathResult,graph,pathFactoryFacade)); - var relationshipType = RelationshipType.withName(formatWithLocale("PATH_%d", pathIndex)); + closeableResourceRegistry.register(resultStream); - // convert internal ids to Neo ids - for (int i = 0; i < nodeIds.length; i++) { - nodeIds[i] = graph.toOriginalNodeId(nodeIds[i]); - } + return resultStream; + } - Path path = null; - if (createCypherPaths) { - path = PathFactory.create( - nodeLookup, - nodeIds, - costs, - relationshipType, - PathFindingStreamResult.COST_PROPERTY_NAME - ); - } + PathFindingStreamResult mapPath(PathResult pathResult, Graph graph, PathFactoryFacade pathFactoryFacade){ + var nodeIds = pathResult.nodeIds(); + var costs = pathResult.costs(); + var pathIndex = pathResult.index(); - return new PathFindingStreamResult( - pathIndex, - graph.toOriginalNodeId(pathResult.sourceNode()), - graph.toOriginalNodeId(pathResult.targetNode()), - pathResult.totalCost(), - // 😿 - Arrays.stream(nodeIds) - .boxed() - .collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))), - Arrays.stream(costs) - .boxed() - .collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))), - path - ); - }); + var relationshipType = RelationshipType.withName(formatWithLocale("PATH_%d", pathIndex)); - closeableResourceRegistry.register(resultStream); + // convert internal ids to Neo ids + for (int i = 0; i < nodeIds.length; i++) { + nodeIds[i] = graph.toOriginalNodeId(nodeIds[i]); + } + var path = pathFactoryFacade.createPath( + nodeIds, + costs, + relationshipType, + PathFindingStreamResult.COST_PROPERTY_NAME + ); - return resultStream; + return new PathFindingStreamResult( + pathIndex, + graph.toOriginalNodeId(pathResult.sourceNode()), + graph.toOriginalNodeId(pathResult.targetNode()), + pathResult.totalCost(), + // 😿 + Arrays.stream(nodeIds) + .boxed() + .collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))), + Arrays.stream(costs) + .boxed() + .collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))), + path + ); } } diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/RandomWalkResultBuilderForStreamMode.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/RandomWalkResultBuilderForStreamMode.java index 923409591a..fda8c02e42 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/RandomWalkResultBuilderForStreamMode.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/RandomWalkResultBuilderForStreamMode.java @@ -25,14 +25,11 @@ import org.neo4j.gds.api.IdMap; import org.neo4j.gds.api.NodeLookup; import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder; -import org.neo4j.gds.paths.PathFactory; -import org.neo4j.graphdb.Path; import org.neo4j.graphdb.RelationshipType; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.function.Function; import java.util.stream.Stream; import static org.neo4j.gds.procedures.algorithms.pathfinding.RandomWalkStreamResult.RELATIONSHIP_TYPE_NAME; @@ -60,17 +57,14 @@ public Stream build( ) { if (result.isEmpty()) return Stream.empty(); - boolean constructPath = returnPath && graphStore.capabilities().canWriteToLocalDatabase(); - Function, Path> pathCreator = constructPath - ? (List nodes) -> PathFactory.create(nodeLookup, nodes, RelationshipType.withName( - RELATIONSHIP_TYPE_NAME)) - : (List nodes) -> null; + + var pathFactoryFacade = PathFactoryFacade.create(returnPath,nodeLookup,graphStore); var streamOfLongArrays = result.get(); var resultStream = streamOfLongArrays.map(nodes -> { var translatedNodes = translateInternalToNeoIds(nodes, graph); - var path = pathCreator.apply(translatedNodes); + var path = pathFactoryFacade.createPath(translatedNodes, RelationshipType.withName(RELATIONSHIP_TYPE_NAME)); return new RandomWalkStreamResult(translatedNodes, path); }); diff --git a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumer.java b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumer.java index c9414afb76..59d4d0bdf4 100644 --- a/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumer.java +++ b/procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumer.java @@ -21,7 +21,6 @@ import org.neo4j.gds.api.NodeLookup; import org.neo4j.gds.collections.ha.HugeLongArray; -import org.neo4j.graphdb.Path; import org.neo4j.graphdb.RelationshipType; import java.util.Arrays; @@ -37,7 +36,6 @@ public static Stream consume( HugeLongArray nodes, LongUnaryOperator toOriginalNodeId, ConcreteResultTransformer resultTransformer, - boolean shouldReturnPath, PathFactoryFacade pathFactoryFacade, RelationshipType relationshipType, NodeLookup nodeLookup @@ -49,14 +47,10 @@ public static Stream consume( .map(toOriginalNodeId::applyAsLong) .collect(Collectors.toList()); - Path path = null; - if (shouldReturnPath) { - path = pathFactoryFacade.createPath( - nodeLookup, + var path = pathFactoryFacade.createPath( nodeList, relationshipType ); - } return Stream.of(resultTransformer.transform( sourceNodeId, diff --git a/procedures/algorithms-facade/src/test/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumerTest.java b/procedures/algorithms-facade/src/test/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumerTest.java index 4cc2ae2a69..24daea90b1 100644 --- a/procedures/algorithms-facade/src/test/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumerTest.java +++ b/procedures/algorithms-facade/src/test/java/org/neo4j/gds/procedures/algorithms/pathfinding/TraverseStreamComputationResultConsumerTest.java @@ -31,25 +31,22 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; class TraverseStreamComputationResultConsumerTest { @Test void shouldNotComputePath() { - var pathFactoryFacadeMock = mock(PathFactoryFacade.class); + var pathFactoryFacade = PathFactoryFacade.create(false,null,null); var result = TraverseStreamComputationResultConsumer.consume( 0L, HugeLongArray.of(1L, 2L), l -> l, TestResult::new, - false, - pathFactoryFacadeMock, + pathFactoryFacade, RelationshipType.withName("TEST"), mock(InternalTransaction.class)::getNodeById ); - verifyNoInteractions(pathFactoryFacadeMock); assertThat(result) .hasSize(1) @@ -63,13 +60,12 @@ void shouldNotComputePath() { @Test void shouldComputePath() { var pathFactoryFacadeMock = mock(PathFactoryFacade.class); - doReturn(mock(Path.class)).when(pathFactoryFacadeMock).createPath(any(), any(), any()); + doReturn(mock(Path.class)).when(pathFactoryFacadeMock).createPath(any(), any()); var result = TraverseStreamComputationResultConsumer.consume( 0L, HugeLongArray.of(1L, 2L), l -> l, TestResult::new, - true, pathFactoryFacadeMock, RelationshipType.withName("TEST"), mock(InternalTransaction.class)::getNodeById diff --git a/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordStreamResult.java b/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordStreamResult.java index d724ba0d4b..4563aeaca9 100644 --- a/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordStreamResult.java +++ b/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/BellmanFordStreamResult.java @@ -19,14 +19,9 @@ */ package org.neo4j.gds.procedures.algorithms.pathfinding; -import org.neo4j.gds.api.IdMap; -import org.neo4j.gds.api.NodeLookup; import org.neo4j.graphdb.Path; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; public final class BellmanFordStreamResult { @@ -66,48 +61,6 @@ public BellmanFordStreamResult( this.isNegativeCycle = isNegativeCycle; } - public static class Builder { - private final IdMap idMap; - private final NodeLookup nodeLookup; - private boolean isCycle; - public Builder(IdMap idMap, NodeLookup nodeLookup) { - this.idMap = idMap; - this.nodeLookup = nodeLookup; - } - public Builder withIsCycle(boolean isCycle) { - this.isCycle = isCycle; - return this; - } - - public BellmanFordStreamResult build(long[] nodeIds, double[] costs, long pathIndex, long sourceNode, long targetNode, double totalCost, boolean createCypherPath) { - // convert internal ids to Neo ids - for (int i = 0; i < nodeIds.length; i++) { - nodeIds[i] = idMap.toOriginalNodeId(nodeIds[i]); - } - - Path path = null; - if (createCypherPath) { - path = StandardStreamPathCreator.create( - nodeLookup, - nodeIds, - costs, - pathIndex - ); - } - - return new BellmanFordStreamResult( - pathIndex, - idMap.toOriginalNodeId(sourceNode), - idMap.toOriginalNodeId(targetNode), - totalCost, - // 😿 - Arrays.stream(nodeIds).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))), - Arrays.stream(costs).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))), - path, - isCycle - ); - } - } } diff --git a/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingStreamResult.java b/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingStreamResult.java index 6d7bb23a67..b1ad409f0b 100644 --- a/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingStreamResult.java +++ b/procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingStreamResult.java @@ -59,48 +59,4 @@ public PathFindingStreamResult( this.path = path; } -// public static class Builder { -// private final IdMap idMap; -// private final NodeLookup nodeLookup; -// -// public Builder(IdMap idMap, NodeLookup nodeLookup) { -// this.idMap = idMap; -// this.nodeLookup = nodeLookup; -// } -// -// public PathFindingStreamResult build(PathResult pathResult, boolean createCypherPath) { -// var nodeIds = pathResult.nodeIds(); -// var costs = pathResult.costs(); -// var pathIndex = pathResult.index(); -// -// -// // convert internal ids to Neo ids -// for (int i = 0; i < nodeIds.length; i++) { -// nodeIds[i] = idMap.toOriginalNodeId(nodeIds[i]); -// } -// -// Path path = null; -// if (createCypherPath) { -// path = StandardStreamPathCreator.create( -// nodeLookup, -// nodeIds, -// costs, -// pathIndex -// ); -// -// -// } -// -// return new PathFindingStreamResult( -// pathIndex, -// idMap.toOriginalNodeId(pathResult.sourceNode()), -// idMap.toOriginalNodeId(pathResult.targetNode()), -// pathResult.totalCost(), -// // 😿 -// Arrays.stream(nodeIds).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))), -// Arrays.stream(costs).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))), -// path -// ); -// } -// } }