diff --git a/algo/src/main/java/org/neo4j/gds/pagerank/PageRankAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/pagerank/PageRankAlgorithmFactory.java index d62d9c6d84..6ead19c2b5 100644 --- a/algo/src/main/java/org/neo4j/gds/pagerank/PageRankAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/pagerank/PageRankAlgorithmFactory.java @@ -44,9 +44,6 @@ public class PageRankAlgorithmFactory extends GraphAlgorithmFactory { - static Task pagerankProgressTask(Graph graph, CONFIG config) { - return Pregel.progressTask(graph, config, "PageRank"); - } private static double averageDegree(Graph graph, int concurrency) { var degreeSum = new LongAdder(); @@ -60,9 +57,19 @@ private static double averageDegree(Graph graph, int concurrency) { } public enum Mode { - PAGE_RANK, - ARTICLE_RANK, - EIGENVECTOR, + PAGE_RANK("PageRank"), + ARTICLE_RANK("ArticleRank"), + EIGENVECTOR("EigenVector"); + + private final String taskName; + + Mode(String taskName) { + this.taskName = taskName; + } + + String taskName() { + return taskName; + } } private final Mode mode; @@ -77,7 +84,7 @@ public PageRankAlgorithmFactory(Mode mode) { @Override public String taskName() { - return mode.name(); + return mode.taskName(); } @Override @@ -128,7 +135,7 @@ public PageRankAlgorithm build( @Override public Task progressTask(Graph graph, CONFIG config) { - return pagerankProgressTask(graph, config); + return Pregel.progressTask(graph, config, taskName()); } @NotNull diff --git a/algo/src/test/java/org/neo4j/gds/pagerank/PageRankTest.java b/algo/src/test/java/org/neo4j/gds/pagerank/PageRankTest.java index 5db2c5025b..e7b63ab09d 100644 --- a/algo/src/test/java/org/neo4j/gds/pagerank/PageRankTest.java +++ b/algo/src/test/java/org/neo4j/gds/pagerank/PageRankTest.java @@ -169,15 +169,19 @@ void withSourceNodes(String sourceNodesString, String expectedPropertyKey) { } } - @Test - void shouldLogProgress() { + @ParameterizedTest + @EnumSource(Mode.class) + void shouldLogProgress(Mode mode) { var maxIterations = 10; var config = ImmutablePageRankConfig.builder() .maxIterations(maxIterations) .build(); - var progressTask = PageRankAlgorithmFactory.pagerankProgressTask(graph, config); + var factory = new PageRankAlgorithmFactory<>(mode); + + var progressTask = factory.progressTask(graph, config); var log = Neo4jProxy.testLog(); + var progressTracker = new TestProgressTracker( progressTask, log, @@ -185,7 +189,12 @@ void shouldLogProgress() { EmptyTaskRegistryFactory.INSTANCE ); - runOnPregel(graph, config, Mode.PAGE_RANK, progressTracker); + factory.build( + graph, + config, + progressTracker + ) + .compute(); var progresses = progressTracker.getProgresses().stream() .filter(it -> it.get() > 0) @@ -207,22 +216,26 @@ void shouldLogProgress() { .extracting(removingThreadId()) .contains( formatWithLocale( - "PageRank :: Compute iteration %d of %d :: Start", + "%s :: Compute iteration %d of %d :: Start", + mode.taskName(), iteration, config.maxIterations() ), formatWithLocale( - "PageRank :: Compute iteration %d of %d :: Finished", + "%s :: Compute iteration %d of %d :: Finished", + mode.taskName(), iteration, config.maxIterations() ), formatWithLocale( - "PageRank :: Master compute iteration %d of %d :: Start", + "%s :: Master compute iteration %d of %d :: Start", + mode.taskName(), iteration, config.maxIterations() ), formatWithLocale( - "PageRank :: Master compute iteration %d of %d :: Finished", + "%s :: Master compute iteration %d of %d :: Finished", + mode.taskName(), iteration, config.maxIterations() ) @@ -231,8 +244,8 @@ void shouldLogProgress() { assertThat(messages) .extracting(removingThreadId()) .contains( - "PageRank :: Start", - "PageRank :: Finished" + formatWithLocale("%s :: Start", mode.taskName()), + formatWithLocale("%s :: Finished", mode.taskName()) ); } @@ -681,16 +694,13 @@ PageRankResult runOnPregel(Graph graph, PageRankConfig config) { } PageRankResult runOnPregel(Graph graph, PageRankConfig config, Mode mode) { - return runOnPregel(graph, config, mode, ProgressTracker.NULL_TRACKER); - } - - PageRankResult runOnPregel(Graph graph, PageRankConfig config, Mode mode, ProgressTracker progressTracker) { return new PageRankAlgorithmFactory<>(mode) .build( graph, config, - progressTracker + ProgressTracker.NULL_TRACKER ) .compute(); } + }