Skip to content

Commit

Permalink
Refactor PathFactoryFacade to avoid the if--checks for each algorithm…
Browse files Browse the repository at this point in the history
… individually
  • Loading branch information
IoannisPanagiotas committed Nov 21, 2024
1 parent e45bd74 commit 6c72cfe
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,28 @@
import org.neo4j.gds.api.CloseableResourceRegistry;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.NodeLookup;
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;
import org.neo4j.gds.paths.PathResult;
import org.neo4j.gds.paths.bellmanford.BellmanFordResult;
import org.neo4j.gds.paths.dijkstra.PathFindingResult;
import org.neo4j.graphdb.RelationshipType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

class BellmanFordResultBuilderForStreamMode implements StreamResultBuilder<BellmanFordResult, BellmanFordStreamResult> {
private final CloseableResourceRegistry closeableResourceRegistry;
private final NodeLookup nodeLookup;
private final boolean routeRequested;
private static final String COST_PROPERTY_NAME = "cost";
private static final String RELATIONSHIP_TYPE_TEMPLATE = "PATH_%d";

BellmanFordResultBuilderForStreamMode(
CloseableResourceRegistry closeableResourceRegistry,
Expand All @@ -56,30 +66,52 @@ public Stream<BellmanFordStreamResult> build(
var bellmanFordResult = result.get();

// this is us handling the case of generated graphs and such
var shouldCreateRoutes = routeRequested && graphStore.capabilities().canWriteToLocalDatabase();

var containsNegativeCycle = bellmanFordResult.containsNegativeCycle();

var resultBuilder = new BellmanFordStreamResult.Builder(graph, nodeLookup)
.withIsCycle(containsNegativeCycle);
var pathFactoryFacade = PathFactoryFacade.create(routeRequested, nodeLookup,graphStore);

var algorithmResult = getPathFindingResult(bellmanFordResult, containsNegativeCycle);
var algorithmResult = getPathFindingResult(bellmanFordResult, bellmanFordResult.containsNegativeCycle());

var resultStream = algorithmResult.mapPaths(path -> resultBuilder.build(
path.nodeIds(),
path.costs(),
path.index(),
path.sourceNode(),
path.targetNode(),
path.totalCost(),
shouldCreateRoutes
var resultStream = algorithmResult.mapPaths(route -> mapRoute(
route,
graph,
bellmanFordResult.containsNegativeCycle(),
pathFactoryFacade
));

closeableResourceRegistry.register(resultStream);

return resultStream;
}

private BellmanFordStreamResult mapRoute(PathResult pathResult, IdMap idMap, boolean negativeCycle,PathFactoryFacade pathFactoryFacade){
var nodeIds = pathResult.nodeIds();
for (int i = 0; i < nodeIds.length; i++) {
nodeIds[i] = idMap.toOriginalNodeId(nodeIds[i]);
}
var relationshipType = RelationshipType.withName(formatWithLocale(RELATIONSHIP_TYPE_TEMPLATE, pathResult.index()));

double[] costs = pathResult.costs();

var path = pathFactoryFacade.createPath(
nodeIds,
costs,
relationshipType,
COST_PROPERTY_NAME
);

return new BellmanFordStreamResult(
pathResult.index(),
idMap.toOriginalNodeId(pathResult.sourceNode()),
idMap.toOriginalNodeId(pathResult.targetNode()),
pathResult.totalCost(),
// 😿
Arrays.stream(nodeIds).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))),
Arrays.stream(costs).boxed().collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))),
path,
negativeCycle
);
}


private static PathFindingResult getPathFindingResult(
BellmanFordResult result,
boolean containsNegativeCycle
Expand All @@ -89,3 +121,4 @@ private static PathFindingResult getPathFindingResult(
return result.shortestPaths();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public Stream<BfsStreamResult> build(
result.get(),
graph::toOriginalNodeId,
BfsStreamResult::new,
pathRequested && graphStore.capabilities().canWriteToLocalDatabase(),
new PathFactoryFacade(),
PathFactoryFacade.create(pathRequested, nodeLookup,graphStore),
RelationshipType.withName(RELATIONSHIP_TYPE_NAME),
nodeLookup
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public Stream<DfsStreamResult> build(
result.get(),
graph::toOriginalNodeId,
DfsStreamResult::new,
pathRequested && graphStore.capabilities().canWriteToLocalDatabase(),
new PathFactoryFacade(),
PathFactoryFacade.create(pathRequested, nodeLookup,graphStore),
RelationshipType.withName(RELATIONSHIP_TYPE_NAME),
nodeLookup
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package org.neo4j.gds.procedures.algorithms.pathfinding;

import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.NodeLookup;
import org.neo4j.gds.paths.PathFactory;
import org.neo4j.graphdb.Path;
Expand All @@ -27,15 +28,44 @@
import java.util.List;

public class PathFactoryFacade {
private final boolean canCreatePaths;
private final NodeLookup nodeLookup;

private PathFactoryFacade(boolean canCreatePaths, NodeLookup nodeLookup) {
this.canCreatePaths = canCreatePaths;
this.nodeLookup = nodeLookup;
}

public static PathFactoryFacade create(boolean pathIsYielded, NodeLookup nodeLookup, GraphStore graphStore){
var canCreatePaths = pathIsYielded && graphStore.capabilities().canWriteToLocalDatabase();
return new PathFactoryFacade(canCreatePaths,nodeLookup);
}

public Path createPath(
NodeLookup nodeLookup,
List<Long> nodeList,
RelationshipType relationshipType
) {
if (!canCreatePaths) return null;
return PathFactory.create(
nodeLookup,
nodeList,
relationshipType
);
}

public Path createPath(
long[] nodeList,
double[] costs,
RelationshipType relationshipType,
String costPropertyName
) {
if (!canCreatePaths) return null;
return PathFactory.create(
nodeLookup,
nodeList,
costs,
relationshipType,
costPropertyName
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.NodeLookup;
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;
import org.neo4j.gds.paths.PathFactory;
import org.neo4j.gds.paths.PathResult;
import org.neo4j.gds.paths.dijkstra.PathFindingResult;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.RelationshipType;

import java.util.ArrayList;
Expand Down Expand Up @@ -60,52 +59,47 @@ public Stream<PathFindingStreamResult> build(
) {
if (result.isEmpty()) return Stream.of();

// this is us handling the case of generated graphs and such
var createCypherPaths = pathRequested && graphStore.capabilities().canWriteToLocalDatabase();

var pathFindingResult = result.get();
var pathFactoryFacade= PathFactoryFacade.create(pathRequested, nodeLookup,graphStore);

var resultStream = pathFindingResult.mapPaths(pathResult -> {
var nodeIds = pathResult.nodeIds();
var costs = pathResult.costs();
var pathIndex = pathResult.index();
var resultStream = pathFindingResult.mapPaths(pathResult -> mapPath(pathResult,graph,pathFactoryFacade));

var relationshipType = RelationshipType.withName(formatWithLocale("PATH_%d", pathIndex));
closeableResourceRegistry.register(resultStream);

// convert internal ids to Neo ids
for (int i = 0; i < nodeIds.length; i++) {
nodeIds[i] = graph.toOriginalNodeId(nodeIds[i]);
}
return resultStream;
}

Path path = null;
if (createCypherPaths) {
path = PathFactory.create(
nodeLookup,
nodeIds,
costs,
relationshipType,
PathFindingStreamResult.COST_PROPERTY_NAME
);
}
PathFindingStreamResult mapPath(PathResult pathResult, Graph graph, PathFactoryFacade pathFactoryFacade){
var nodeIds = pathResult.nodeIds();
var costs = pathResult.costs();
var pathIndex = pathResult.index();

return new PathFindingStreamResult(
pathIndex,
graph.toOriginalNodeId(pathResult.sourceNode()),
graph.toOriginalNodeId(pathResult.targetNode()),
pathResult.totalCost(),
// 😿
Arrays.stream(nodeIds)
.boxed()
.collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))),
Arrays.stream(costs)
.boxed()
.collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))),
path
);
});
var relationshipType = RelationshipType.withName(formatWithLocale("PATH_%d", pathIndex));

closeableResourceRegistry.register(resultStream);
// convert internal ids to Neo ids
for (int i = 0; i < nodeIds.length; i++) {
nodeIds[i] = graph.toOriginalNodeId(nodeIds[i]);
}
var path = pathFactoryFacade.createPath(
nodeIds,
costs,
relationshipType,
PathFindingStreamResult.COST_PROPERTY_NAME
);

return resultStream;
return new PathFindingStreamResult(
pathIndex,
graph.toOriginalNodeId(pathResult.sourceNode()),
graph.toOriginalNodeId(pathResult.targetNode()),
pathResult.totalCost(),
// 😿
Arrays.stream(nodeIds)
.boxed()
.collect(Collectors.toCollection(() -> new ArrayList<>(nodeIds.length))),
Arrays.stream(costs)
.boxed()
.collect(Collectors.toCollection(() -> new ArrayList<>(costs.length))),
path
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,11 @@
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.NodeLookup;
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;
import org.neo4j.gds.paths.PathFactory;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.RelationshipType;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;

import static org.neo4j.gds.procedures.algorithms.pathfinding.RandomWalkStreamResult.RELATIONSHIP_TYPE_NAME;
Expand Down Expand Up @@ -60,17 +57,14 @@ public Stream<RandomWalkStreamResult> build(
) {
if (result.isEmpty()) return Stream.empty();

boolean constructPath = returnPath && graphStore.capabilities().canWriteToLocalDatabase();
Function<List<Long>, Path> pathCreator = constructPath
? (List<Long> nodes) -> PathFactory.create(nodeLookup, nodes, RelationshipType.withName(
RELATIONSHIP_TYPE_NAME))
: (List<Long> nodes) -> null;

var pathFactoryFacade = PathFactoryFacade.create(returnPath,nodeLookup,graphStore);

var streamOfLongArrays = result.get();

var resultStream = streamOfLongArrays.map(nodes -> {
var translatedNodes = translateInternalToNeoIds(nodes, graph);
var path = pathCreator.apply(translatedNodes);
var path = pathFactoryFacade.createPath(translatedNodes, RelationshipType.withName(RELATIONSHIP_TYPE_NAME));
return new RandomWalkStreamResult(translatedNodes, path);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.neo4j.gds.api.NodeLookup;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.RelationshipType;

import java.util.Arrays;
Expand All @@ -37,7 +36,6 @@ public static <T> Stream<T> consume(
HugeLongArray nodes,
LongUnaryOperator toOriginalNodeId,
ConcreteResultTransformer<T> resultTransformer,
boolean shouldReturnPath,
PathFactoryFacade pathFactoryFacade,
RelationshipType relationshipType,
NodeLookup nodeLookup
Expand All @@ -49,14 +47,10 @@ public static <T> Stream<T> consume(
.map(toOriginalNodeId::applyAsLong)
.collect(Collectors.toList());

Path path = null;
if (shouldReturnPath) {
path = pathFactoryFacade.createPath(
nodeLookup,
var path = pathFactoryFacade.createPath(
nodeList,
relationshipType
);
}

return Stream.of(resultTransformer.transform(
sourceNodeId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,22 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;

class TraverseStreamComputationResultConsumerTest {

@Test
void shouldNotComputePath() {
var pathFactoryFacadeMock = mock(PathFactoryFacade.class);
var pathFactoryFacade = PathFactoryFacade.create(false,null,null);
var result = TraverseStreamComputationResultConsumer.consume(
0L,
HugeLongArray.of(1L, 2L),
l -> l,
TestResult::new,
false,
pathFactoryFacadeMock,
pathFactoryFacade,
RelationshipType.withName("TEST"),
mock(InternalTransaction.class)::getNodeById
);

verifyNoInteractions(pathFactoryFacadeMock);

assertThat(result)
.hasSize(1)
Expand All @@ -63,13 +60,12 @@ void shouldNotComputePath() {
@Test
void shouldComputePath() {
var pathFactoryFacadeMock = mock(PathFactoryFacade.class);
doReturn(mock(Path.class)).when(pathFactoryFacadeMock).createPath(any(), any(), any());
doReturn(mock(Path.class)).when(pathFactoryFacadeMock).createPath(any(), any());
var result = TraverseStreamComputationResultConsumer.consume(
0L,
HugeLongArray.of(1L, 2L),
l -> l,
TestResult::new,
true,
pathFactoryFacadeMock,
RelationshipType.withName("TEST"),
mock(InternalTransaction.class)::getNodeById
Expand Down
Loading

0 comments on commit 6c72cfe

Please sign in to comment.