From 9a53a9835685b2710aebac034e59a66eb24b1e6c Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 26 Nov 2024 15:15:38 +0000 Subject: [PATCH] Register requested concurrency before running algorithm Co-authored-by: Ioannis Panagiotas --- .../centrality/CentralityAlgorithms.java | 63 ++++++++-- .../community/CommunityAlgorithms.java | 98 ++++++++++++--- .../MachineLearningAlgorithms.java | 9 +- .../machinery/AlgorithmMachinery.java | 5 +- .../machinery/AlgorithmMachineryTest.java | 45 ++++--- .../MiscellaneousAlgorithms.java | 21 +++- .../embeddings/NodeEmbeddingAlgorithms.java | 35 +++++- .../pathfinding/PathFindingAlgorithms.java | 115 ++++++++++++++---- .../traverse/BreadthFirstSearch.java | 7 +- .../traverse/DepthFirstSearch.java | 7 +- .../similarity/SimilarityAlgorithms.java | 28 ++++- .../NodeClassificationTrainComputation.java | 7 +- .../NodeRegressionTrainComputation.java | 7 +- 13 files changed, 370 insertions(+), 77 deletions(-) diff --git a/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java b/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java index be1aae56e0..7216cf4605 100644 --- a/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java +++ b/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java @@ -125,7 +125,12 @@ BitSet articulationPoints(Graph graph, AlgoBaseConfig configuration) { var algorithm = new ArticulationPoints(graph, progressTracker); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentralityBaseConfig configuration) { @@ -168,7 +173,12 @@ public BetwennessCentralityResult betweennessCentrality( terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + parameters.concurrency() + ); } BridgeResult bridges(Graph graph, AlgoBaseConfig configuration) { @@ -178,7 +188,12 @@ BridgeResult bridges(Graph graph, AlgoBaseConfig configuration) { var algorithm = new Bridges(graph, progressTracker); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuration) { @@ -191,7 +206,12 @@ public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuratio var algorithm = new CELF(graph, configuration.toParameters(), DefaultPool.INSTANCE, progressTracker); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) { @@ -218,7 +238,12 @@ ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBa terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig configuration) { @@ -237,7 +262,12 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } PageRankResult eigenVector(Graph graph, EigenvectorConfig configuration) { @@ -279,7 +309,12 @@ HarmonicResult harmonicCentrality(Graph graph, AlgoBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } PregelResult hits(Graph graph, HitsConfig configuration) { @@ -297,7 +332,12 @@ PregelResult hits(Graph graph, HitsConfig configuration) { progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig configuration) { @@ -315,7 +355,12 @@ IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig conf progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public PageRankResult pageRank(Graph graph, PageRankConfig configuration) { diff --git a/applications/algorithms/community/src/main/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithms.java b/applications/algorithms/community/src/main/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithms.java index 626ecdb01c..4598e68ff1 100644 --- a/applications/algorithms/community/src/main/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithms.java +++ b/applications/algorithms/community/src/main/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithms.java @@ -120,7 +120,12 @@ ApproxMaxKCutResult approximateMaximumKCut(Graph graph, ApproxMaxKCutBaseConfig terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration) { @@ -144,7 +149,12 @@ ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration) progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configuration) { @@ -164,7 +174,12 @@ K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } KCoreDecompositionResult kCore(Graph graph, AlgoBaseConfig configuration) { @@ -173,7 +188,12 @@ KCoreDecompositionResult kCore(Graph graph, AlgoBaseConfig configuration) { var algorithm = new KCoreDecomposition(graph, configuration.concurrency(), progressTracker, terminationFlag); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public KmeansResult kMeans(Graph graph, KmeansBaseConfig configuration) { @@ -194,7 +214,12 @@ public KmeansResult kMeans(Graph graph, KmeansBaseConfig configuration) { .build(); var algorithm = Kmeans.createKmeans(graph, configuration.toParameters(), kmeansContext, terminationFlag); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } LabelPropagationResult labelPropagation(Graph graph, LabelPropagationBaseConfig configuration) { @@ -217,7 +242,12 @@ LabelPropagationResult labelPropagation(Graph graph, LabelPropagationBaseConfig terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } LocalClusteringCoefficientResult lcc(Graph graph, LocalClusteringCoefficientBaseConfig configuration) { @@ -240,7 +270,12 @@ LocalClusteringCoefficientResult lcc(Graph graph, LocalClusteringCoefficientBase terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } LeidenResult leiden(Graph graph, LeidenBaseConfig configuration) { @@ -284,7 +319,12 @@ LeidenResult leiden(Graph graph, LeidenBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } LouvainResult louvain(Graph graph, LouvainBaseConfig configuration) { @@ -307,7 +347,12 @@ LouvainResult louvain(Graph graph, LouvainBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } ModularityResult modularity(Graph graph, ModularityBaseConfig configuration) { @@ -350,7 +395,12 @@ ModularityOptimizationResult modularityOptimization(Graph graph, ModularityOptim terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } HugeLongArray scc(Graph graph, SccCommonBaseConfig configuration) { @@ -361,7 +411,12 @@ HugeLongArray scc(Graph graph, SccCommonBaseConfig configuration) { var algorithm = new Scc(graph, progressTracker, terminationFlag); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } TriangleCountResult triangleCount(Graph graph, TriangleCountBaseConfig configuration) { @@ -379,7 +434,12 @@ TriangleCountResult triangleCount(Graph graph, TriangleCountBaseConfig configura terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } Stream triangles(Graph graph, ConcurrencyConfig configuration) { @@ -411,7 +471,12 @@ DisjointSetStruct wcc(Graph graph, WccBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } private Task constructKMeansProgressTask(Graph graph, KmeansBaseConfig configuration) { @@ -482,6 +547,11 @@ PregelResult speakerListenerLPA(Graph graph, SpeakerListenerLPAConfig configurat Optional.empty() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } } diff --git a/applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java b/applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java index bd4211e2fc..9bcc656b0f 100644 --- a/applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java +++ b/applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java @@ -21,8 +21,8 @@ import com.carrotsearch.hppc.BitSet; import org.neo4j.gds.algorithms.machinelearning.KGEPredictBaseConfig; -import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult; import org.neo4j.gds.algorithms.machinelearning.KGEPredictConfigTransformer; +import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult; import org.neo4j.gds.algorithms.machinelearning.TopKMapComputer; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; @@ -79,7 +79,12 @@ KGEPredictResult kge(Graph graph, KGEPredictBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } EdgeSplitter.SplitResult splitRelationships(GraphStore graphStore, SplitRelationshipsBaseConfig configuration) { diff --git a/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java index f23ba6f17c..02a953f4fe 100644 --- a/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java +++ b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java @@ -20,6 +20,7 @@ package org.neo4j.gds.applications.algorithms.machinery; import org.neo4j.gds.Algorithm; +import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; /** @@ -40,9 +41,11 @@ public class AlgorithmMachinery { public RESULT runAlgorithmsAndManageProgressTracker( Algorithm algorithm, ProgressTracker progressTracker, - boolean shouldReleaseProgressTracker + boolean shouldReleaseProgressTracker, + Concurrency concurrency ) { try { + progressTracker.requestedConcurrency(concurrency); return algorithm.compute(); } catch (Exception e) { progressTracker.endSubTaskWithFailure(); diff --git a/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java index e917f01771..053657f485 100644 --- a/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java +++ b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java @@ -20,36 +20,47 @@ package org.neo4j.gds.applications.algorithms.machinery; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.neo4j.gds.Algorithm; +import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +@ExtendWith(MockitoExtension.class) class AlgorithmMachineryTest { + private static final Concurrency CONCURRENCY = new Concurrency(4); + @Mock + private Algorithm algo; + @Test void shouldRunAlgorithm() { var algorithmMachinery = new AlgorithmMachinery(); var progressTracker = mock(ProgressTracker.class); - var algo = mock(Algorithm.class); when(algo.compute()).thenReturn("Hello, world!"); var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker( algo, progressTracker, - false + false, + CONCURRENCY ); assertThat(result).isEqualTo("Hello, world!"); - verifyNoInteractions(progressTracker); + verify(progressTracker, times(1)).requestedConcurrency(CONCURRENCY); + verifyNoMoreInteractions(progressTracker); } @Test @@ -58,18 +69,20 @@ void shouldReleaseProgressTrackerWhenAsked() { var progressTracker = mock(ProgressTracker.class); - var algo = mock(Algorithm.class); when(algo.compute()).thenReturn("Dodgers win world series!"); var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker( algo, progressTracker, - true + true, + CONCURRENCY ); assertThat(result).isEqualTo("Dodgers win world series!"); - verify(progressTracker).release(); + verify(progressTracker, times(1)).requestedConcurrency(CONCURRENCY); + verify(progressTracker, times(1)).release(); + verifyNoMoreInteractions(progressTracker); } @Test @@ -79,21 +92,23 @@ void shouldMarkProgressTracker() { var progressTracker = mock(ProgressTracker.class); var exception = new RuntimeException("Whoops!"); - var algo = mock(Algorithm.class); when(algo.compute()).thenThrow(exception); try { algorithmMachinery.runAlgorithmsAndManageProgressTracker( algo, progressTracker, - false + false, + CONCURRENCY ); fail(); } catch (Exception e) { assertThat(e).hasMessage("Whoops!"); } - verify(progressTracker).endSubTaskWithFailure(); + verify(progressTracker, times(1)).requestedConcurrency(CONCURRENCY); + verify(progressTracker, times(1)).endSubTaskWithFailure(); + verifyNoMoreInteractions(progressTracker); } @Test @@ -103,21 +118,23 @@ void shouldMarkProgressTrackerAndReleaseIt() { var progressTracker = mock(ProgressTracker.class); var exception = new RuntimeException("Yeah, no..."); - var algo = mock(Algorithm.class); when(algo.compute()).thenThrow(exception); try { algorithmMachinery.runAlgorithmsAndManageProgressTracker( algo, progressTracker, - true + true, + CONCURRENCY ); fail(); } catch (Exception e) { assertThat(e).hasMessage("Yeah, no..."); } - verify(progressTracker).endSubTaskWithFailure(); - verify(progressTracker).release(); + verify(progressTracker, times(1)).requestedConcurrency(CONCURRENCY); + verify(progressTracker, times(1)).endSubTaskWithFailure(); + verify(progressTracker, times(1)).release(); + verifyNoMoreInteractions(progressTracker); } } diff --git a/applications/algorithms/miscellaneous-algorithms/src/main/java/org/neo4j/gds/applications/algorithms/miscellaneous/MiscellaneousAlgorithms.java b/applications/algorithms/miscellaneous-algorithms/src/main/java/org/neo4j/gds/applications/algorithms/miscellaneous/MiscellaneousAlgorithms.java index 5fa26e83d4..3aa30cb16b 100644 --- a/applications/algorithms/miscellaneous-algorithms/src/main/java/org/neo4j/gds/applications/algorithms/miscellaneous/MiscellaneousAlgorithms.java +++ b/applications/algorithms/miscellaneous-algorithms/src/main/java/org/neo4j/gds/applications/algorithms/miscellaneous/MiscellaneousAlgorithms.java @@ -113,7 +113,12 @@ Map indexInverse( terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } ScalePropertiesResult scaleProperties(Graph graph, ScalePropertiesBaseConfig configuration) { @@ -137,7 +142,12 @@ ScalePropertiesResult scaleProperties(Graph graph, ScalePropertiesBaseConfig con DefaultPool.INSTANCE ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } SingleTypeRelationships toUndirected(GraphStore graphStore, ToUndirectedConfig configuration) { @@ -156,6 +166,11 @@ SingleTypeRelationships toUndirected(GraphStore graphStore, ToUndirectedConfig c terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } } diff --git a/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithms.java b/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithms.java index 4c109858ec..e1620e839d 100644 --- a/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithms.java +++ b/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithms.java @@ -99,7 +99,12 @@ public FastRPResult fastRP(Graph graph, FastRPBaseConfig configuration, Progress terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } GraphSageResult graphSage(Graph graph, GraphSageBaseConfig configuration) { @@ -119,7 +124,12 @@ GraphSageResult graphSage(Graph graph, GraphSageBaseConfig configuration) { terminationFlag ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } Model graphSageTrain( @@ -146,7 +156,12 @@ Model randomWalk(Graph graph, RandomWalkBaseConfig configuration) { @@ -245,22 +267,33 @@ Stream randomWalk(Graph graph, RandomWalkBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + false, + configuration.concurrency() + ); } PrizeSteinerTreeResult pcst(Graph graph, PCSTBaseConfig configuration) { - var task = PCSTProgressTrackerTaskCreator.progressTask(graph.nodeCount(),graph.relationshipCount()); + var task = PCSTProgressTrackerTaskCreator.progressTask(graph.nodeCount(), graph.relationshipCount()); var progressTracker = createProgressTracker(configuration, task); var prizeProperty = graph.nodeProperties(configuration.prizeProperty()); var algorithm = new PCSTFast( graph, - (v) -> Math.max(prizeProperty.longValue(v),0), + (v) -> Math.max(prizeProperty.longValue(v), 0), progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } + HugeAtomicLongArray randomWalkCountingNodeVisits(Graph graph, RandomWalkBaseConfig configuration) { var tasks = new ArrayList(); if (graph.hasRelationshipProperty()) { @@ -286,7 +319,12 @@ HugeAtomicLongArray randomWalkCountingNodeVisits(Graph graph, RandomWalkBaseConf requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAStarBaseConfig configuration) { @@ -302,7 +340,12 @@ public PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAS requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + false, + configuration.concurrency() + ); } /** @@ -327,7 +370,12 @@ PathFindingResult singlePairShortestPathDijkstra(Graph graph, DijkstraSourceTarg requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + false, + configuration.concurrency() + ); } public PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseConfig configuration) { @@ -356,7 +404,12 @@ public PathFindingResult singlePairShortestPathYens( requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + false, + configuration.concurrency() + ); } PathFindingResult singleSourceShortestPathDijkstra(Graph graph, DijkstraBaseConfig configuration) { @@ -374,7 +427,12 @@ PathFindingResult singleSourceShortestPathDijkstra(Graph graph, DijkstraBaseConf requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + false, + configuration.concurrency() + ); } public SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configuration) { @@ -397,7 +455,12 @@ public SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configurati requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configuration) { @@ -430,7 +493,12 @@ public SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configur requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } public TopologicalSortResult topologicalSort(Graph graph, TopologicalSortBaseConfig configuration) { @@ -450,7 +518,12 @@ public TopologicalSortResult topologicalSort(Graph graph, TopologicalSortBaseCon requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } private MSBFSASPAlgorithm selectAlgorithm(Graph graph, AllShortestPathsConfig configuration) { diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java index e106739346..f31e1f00ed 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java @@ -68,6 +68,11 @@ public HugeLongArray compute(Graph graph, BfsBaseConfig configuration, ProgressT terminationFlag ); - return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker(bfs, progressTracker, true); + return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker( + bfs, + progressTracker, + true, + configuration.concurrency() + ); } } diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java index b6d4c9b3dc..2f08bfbca1 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java @@ -67,6 +67,11 @@ public HugeLongArray compute(Graph graph, DfsBaseConfig configuration, ProgressT terminationFlag ); - return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker(dfs, progressTracker, true); + return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker( + dfs, + progressTracker, + true, + configuration.concurrency() + ); } } diff --git a/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java b/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java index b8564fc836..caa5e53b52 100644 --- a/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java +++ b/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java @@ -92,7 +92,12 @@ FilteredKnnResult filteredKnn(Graph graph, FilteredKnnBaseConfig configuration) var algorithm = selectAlgorithmConfiguration(graph, configuration, knnContext); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } NodeSimilarityResult filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityBaseConfig configuration) { @@ -118,7 +123,12 @@ NodeSimilarityResult filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityB requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } KnnResult knn(Graph graph, KnnBaseConfig configuration) { @@ -156,7 +166,12 @@ KnnResult knn(Graph graph, KnnBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig configuration) { @@ -185,7 +200,12 @@ NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig config requestScopedDependencies.getTerminationFlag() ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } private Task filteredNodeSimilarityProgressTask(Graph graph, boolean runWcc) { diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationTrainComputation.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationTrainComputation.java index 7e8fdd9314..f8b88c2cab 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationTrainComputation.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationTrainComputation.java @@ -204,6 +204,11 @@ public NodeClassificationModelResult compute(Graph graph, GraphStore graphStore) progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionTrainComputation.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionTrainComputation.java index 0235f2ce40..d5e6778eb7 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionTrainComputation.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionTrainComputation.java @@ -204,6 +204,11 @@ public NodeRegressionTrainResult.NodeRegressionTrainPipelineResult compute(Graph progressTracker ); - return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker( + algorithm, + progressTracker, + true, + configuration.concurrency() + ); } }