From 889bb84df274e72ba00ae201b3988e1a50231928 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Wed, 13 Sep 2023 15:58:20 +0100 Subject: [PATCH] Fix conflicts from KGEMutateProc Co-authored-by: Olga Razvenskaia --- .../main/java/org/neo4j/gds/ml/core/tensor/Tensor.java | 10 ++++++++++ .../org/neo4j/gds/ml/kge/KGEMutateResultConsumer.java | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ml/ml-core/src/main/java/org/neo4j/gds/ml/core/tensor/Tensor.java b/ml/ml-core/src/main/java/org/neo4j/gds/ml/core/tensor/Tensor.java index 82459d76fa..6c4ef7e333 100644 --- a/ml/ml-core/src/main/java/org/neo4j/gds/ml/core/tensor/Tensor.java +++ b/ml/ml-core/src/main/java/org/neo4j/gds/ml/core/tensor/Tensor.java @@ -109,6 +109,16 @@ public int totalSize() { return Dimensions.totalSize(dimensions); } + // TODO: figure out how to replace this one + public SELF elementwiseProduct(Tensor other) { + var result = createWithSameDimensions(); + for (int i = 0; i < data.length; i++) { + result.data[i] = data[i] * other.data[i]; + } + return result; + } + + public Tensor elementwiseProductMutate(Tensor other) { for (int i = 0; i < data.length; i++) { this.data[i] = data[i] * other.data[i]; diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/KGEMutateResultConsumer.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/KGEMutateResultConsumer.java index 135a0c25da..014c526cf5 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/KGEMutateResultConsumer.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/KGEMutateResultConsumer.java @@ -24,8 +24,8 @@ import org.neo4j.gds.RelationshipType; import org.neo4j.gds.ResultBuilderFunction; import org.neo4j.gds.core.Aggregation; +import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; -import org.neo4j.gds.core.concurrency.Pools; import org.neo4j.gds.core.loading.construction.GraphFactory; import org.neo4j.gds.core.utils.TerminationFlag; import org.neo4j.gds.executor.ComputationResult; @@ -64,7 +64,7 @@ protected void updateGraphStore( .orientation(Orientation.NATURAL) .addPropertyConfig(GraphFactory.PropertyConfig.of(KGE_PREDICT_MUTATE_PROPERTY)) .concurrency(concurrency) - .executorService(Pools.DEFAULT) + .executorService(DefaultPool.INSTANCE) .build(); var similarityResultStream = computationResult.result()