Skip to content

Commit

Permalink
Register requested concurrency before running algorithm
Browse files Browse the repository at this point in the history
Co-authored-by: Ioannis Panagiotas <[email protected]>
  • Loading branch information
vnickolov and IoannisPanagiotas committed Nov 27, 2024
1 parent 84ea16a commit 9a53a98
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,12 @@ BitSet articulationPoints(Graph graph, AlgoBaseConfig configuration) {

var algorithm = new ArticulationPoints(graph, progressTracker);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentralityBaseConfig configuration) {
Expand Down Expand Up @@ -168,7 +173,12 @@ public BetwennessCentralityResult betweennessCentrality(
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
parameters.concurrency()
);
}

BridgeResult bridges(Graph graph, AlgoBaseConfig configuration) {
Expand All @@ -178,7 +188,12 @@ BridgeResult bridges(Graph graph, AlgoBaseConfig configuration) {

var algorithm = new Bridges(graph, progressTracker);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuration) {
Expand All @@ -191,7 +206,12 @@ public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuratio

var algorithm = new CELF(graph, configuration.toParameters(), DefaultPool.INSTANCE, progressTracker);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) {
Expand All @@ -218,7 +238,12 @@ ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBa
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig configuration) {
Expand All @@ -237,7 +262,12 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf
progressTracker
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

PageRankResult eigenVector(Graph graph, EigenvectorConfig configuration) {
Expand Down Expand Up @@ -279,7 +309,12 @@ HarmonicResult harmonicCentrality(Graph graph, AlgoBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

PregelResult hits(Graph graph, HitsConfig configuration) {
Expand All @@ -297,7 +332,12 @@ PregelResult hits(Graph graph, HitsConfig configuration) {
progressTracker
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig configuration) {
Expand All @@ -315,7 +355,12 @@ IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig conf
progressTracker
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

public PageRankResult pageRank(Graph graph, PageRankConfig configuration) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ ApproxMaxKCutResult approximateMaximumKCut(Graph graph, ApproxMaxKCutBaseConfig
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration) {
Expand All @@ -144,7 +149,12 @@ ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration)
progressTracker
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configuration) {
Expand All @@ -164,7 +174,12 @@ K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

KCoreDecompositionResult kCore(Graph graph, AlgoBaseConfig configuration) {
Expand All @@ -173,7 +188,12 @@ KCoreDecompositionResult kCore(Graph graph, AlgoBaseConfig configuration) {

var algorithm = new KCoreDecomposition(graph, configuration.concurrency(), progressTracker, terminationFlag);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

public KmeansResult kMeans(Graph graph, KmeansBaseConfig configuration) {
Expand All @@ -194,7 +214,12 @@ public KmeansResult kMeans(Graph graph, KmeansBaseConfig configuration) {
.build();
var algorithm = Kmeans.createKmeans(graph, configuration.toParameters(), kmeansContext, terminationFlag);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

LabelPropagationResult labelPropagation(Graph graph, LabelPropagationBaseConfig configuration) {
Expand All @@ -217,7 +242,12 @@ LabelPropagationResult labelPropagation(Graph graph, LabelPropagationBaseConfig
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

LocalClusteringCoefficientResult lcc(Graph graph, LocalClusteringCoefficientBaseConfig configuration) {
Expand All @@ -240,7 +270,12 @@ LocalClusteringCoefficientResult lcc(Graph graph, LocalClusteringCoefficientBase
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

LeidenResult leiden(Graph graph, LeidenBaseConfig configuration) {
Expand Down Expand Up @@ -284,7 +319,12 @@ LeidenResult leiden(Graph graph, LeidenBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

LouvainResult louvain(Graph graph, LouvainBaseConfig configuration) {
Expand All @@ -307,7 +347,12 @@ LouvainResult louvain(Graph graph, LouvainBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

ModularityResult modularity(Graph graph, ModularityBaseConfig configuration) {
Expand Down Expand Up @@ -350,7 +395,12 @@ ModularityOptimizationResult modularityOptimization(Graph graph, ModularityOptim
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

HugeLongArray scc(Graph graph, SccCommonBaseConfig configuration) {
Expand All @@ -361,7 +411,12 @@ HugeLongArray scc(Graph graph, SccCommonBaseConfig configuration) {

var algorithm = new Scc(graph, progressTracker, terminationFlag);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

TriangleCountResult triangleCount(Graph graph, TriangleCountBaseConfig configuration) {
Expand All @@ -379,7 +434,12 @@ TriangleCountResult triangleCount(Graph graph, TriangleCountBaseConfig configura
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

Stream<TriangleResult> triangles(Graph graph, ConcurrencyConfig configuration) {
Expand Down Expand Up @@ -411,7 +471,12 @@ DisjointSetStruct wcc(Graph graph, WccBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

private Task constructKMeansProgressTask(Graph graph, KmeansBaseConfig configuration) {
Expand Down Expand Up @@ -482,6 +547,11 @@ PregelResult speakerListenerLPA(Graph graph, SpeakerListenerLPAConfig configurat
Optional.empty()
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import com.carrotsearch.hppc.BitSet;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictBaseConfig;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictConfigTransformer;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult;
import org.neo4j.gds.algorithms.machinelearning.TopKMapComputer;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
Expand Down Expand Up @@ -79,7 +79,12 @@ KGEPredictResult kge(Graph graph, KGEPredictBaseConfig configuration) {
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
progressTracker,
true,
configuration.concurrency()
);
}

EdgeSplitter.SplitResult splitRelationships(GraphStore graphStore, SplitRelationshipsBaseConfig configuration) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.gds.applications.algorithms.machinery;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

/**
Expand All @@ -40,9 +41,11 @@ public class AlgorithmMachinery {
public <RESULT> RESULT runAlgorithmsAndManageProgressTracker(
Algorithm<RESULT> algorithm,
ProgressTracker progressTracker,
boolean shouldReleaseProgressTracker
boolean shouldReleaseProgressTracker,
Concurrency concurrency
) {
try {
progressTracker.requestedConcurrency(concurrency);
return algorithm.compute();
} catch (Exception e) {
progressTracker.endSubTaskWithFailure();
Expand Down
Loading

0 comments on commit 9a53a98

Please sign in to comment.