Skip to content

Commit

Permalink
Merge pull request #8131 from lassewesth/samples
Browse files Browse the repository at this point in the history
make graph sampling reusable
  • Loading branch information
lassewesth authored Sep 12, 2023
2 parents 1e1a5f3 + 391ac32 commit 23ddb43
Show file tree
Hide file tree
Showing 17 changed files with 606 additions and 252 deletions.
1 change: 1 addition & 0 deletions applications/graph-store-catalog/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
implementation project(':core-write')
implementation project(':executor')
implementation project(':graph-projection-api')
implementation project(':graph-sampling')
implementation project(':logging')
implementation project(':memory-usage')
implementation project(':native-projection')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.neo4j.gds.core.CypherMapAccess;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.loading.GraphStoreWithConfig;
import org.neo4j.gds.graphsampling.config.CommonNeighbourAwareRandomWalkConfig;
import org.neo4j.gds.projection.GraphProjectFromStoreConfig;

import java.util.Collection;
Expand Down Expand Up @@ -329,6 +330,12 @@ GraphWriteRelationshipConfig parseGraphWriteRelationshipConfiguration(
return configuration;
}

/**
 * Turns the raw, caller-supplied configuration map into a typed
 * {@link CommonNeighbourAwareRandomWalkConfig}.
 * <p>
 * The map is wrapped in a {@link CypherMapWrapper} and handed to
 * {@code CommonNeighbourAwareRandomWalkConfig.of}. This method itself does not call
 * {@code ensureThereAreNoExtraConfigurationKeys}.
 *
 * @param rawConfiguration untyped configuration entries
 * @return the configuration built from {@code rawConfiguration}
 */
CommonNeighbourAwareRandomWalkConfig parseCommonNeighbourAwareRandomWalkConfig(Map<String, Object> rawConfiguration) {
    return CommonNeighbourAwareRandomWalkConfig.of(CypherMapWrapper.create(rawConfiguration));
}

/**
 * Asks {@code cypherConfig} to require that it only contains keys taken from
 * {@code config.configKeys()}.
 *
 * @param cypherConfig the raw configuration as supplied by the caller
 * @param config       the parsed configuration whose declared keys are the permitted ones
 */
private void ensureThereAreNoExtraConfigurationKeys(CypherMapAccess cypherConfig, BaseConfig config) {
    var permittedKeys = config.configKeys();

    cypherConfig.requireOnlyKeysFrom(permittedKeys);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.neo4j.gds.api.User;
import org.neo4j.gds.config.MutateLabelConfig;
import org.neo4j.gds.config.WriteLabelConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.loading.CatalogRequest;
import org.neo4j.gds.core.loading.GraphDropNodePropertiesResult;
import org.neo4j.gds.core.loading.GraphDropRelationshipResult;
Expand All @@ -35,16 +36,24 @@
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
import org.neo4j.gds.graphsampling.RandomWalkBasedNodesSampler;
import org.neo4j.gds.graphsampling.config.RandomWalkWithRestartsConfig;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.projection.GraphProjectNativeResult;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.transaction.TransactionContext;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.neo4j.gds.applications.graphstorecatalog.SamplerCompanion.CNARW_CONFIG_PROVIDER;
import static org.neo4j.gds.applications.graphstorecatalog.SamplerCompanion.CNARW_PROVIDER;
import static org.neo4j.gds.applications.graphstorecatalog.SamplerCompanion.RWR_CONFIG_PROVIDER;
import static org.neo4j.gds.applications.graphstorecatalog.SamplerCompanion.RWR_PROVIDER;

/**
* This layer is shared between Neo4j and other integrations. It is entry-point agnostic.
* "Business facade" to distinguish it from "procedure facade" and similar.
Expand Down Expand Up @@ -90,6 +99,8 @@ public class DefaultGraphStoreCatalogBusinessFacade implements GraphStoreCatalog
private final WriteRelationshipPropertiesApplication writeRelationshipPropertiesApplication;
private final WriteNodeLabelApplication writeNodeLabelApplication;
private final WriteRelationshipsApplication writeRelationshipsApplication;
private final GraphSamplingApplication graphSamplingApplication;
private final EstimateCommonNeighbourAwareRandomWalkApplication estimateCommonNeighbourAwareRandomWalkApplication;

public DefaultGraphStoreCatalogBusinessFacade(
Log log,
Expand All @@ -112,7 +123,9 @@ public DefaultGraphStoreCatalogBusinessFacade(
WriteNodePropertiesApplication writeNodePropertiesApplication,
WriteRelationshipPropertiesApplication writeRelationshipPropertiesApplication,
WriteNodeLabelApplication writeNodeLabelApplication,
WriteRelationshipsApplication writeRelationshipsApplication
WriteRelationshipsApplication writeRelationshipsApplication,
GraphSamplingApplication graphSamplingApplication,
EstimateCommonNeighbourAwareRandomWalkApplication estimateCommonNeighbourAwareRandomWalkApplication
) {
this.log = log;

Expand All @@ -137,6 +150,8 @@ public DefaultGraphStoreCatalogBusinessFacade(
this.writeRelationshipPropertiesApplication = writeRelationshipPropertiesApplication;
this.writeNodeLabelApplication = writeNodeLabelApplication;
this.writeRelationshipsApplication = writeRelationshipsApplication;
this.graphSamplingApplication = graphSamplingApplication;
this.estimateCommonNeighbourAwareRandomWalkApplication = estimateCommonNeighbourAwareRandomWalkApplication;
}

@Override
Expand Down Expand Up @@ -753,6 +768,96 @@ public WriteRelationshipResult writeRelationships(
);
}

/**
 * Samples the graph {@code originGraphName} into a new graph {@code graphName}.
 * <p>
 * All work is delegated to {@code sampleRandomWalk}; this entry point only selects
 * {@code RWR_CONFIG_PROVIDER} and {@code RWR_PROVIDER} as the configuration and sampler factories.
 */
@Override
public RandomWalkSamplingResult sampleRandomWalkWithRestarts(
    User user,
    DatabaseId databaseId,
    TaskRegistryFactory taskRegistryFactory,
    UserLogRegistryFactory userLogRegistryFactory,
    String graphName,
    String originGraphName,
    Map<String, Object> configuration
) {
    return sampleRandomWalk(
        user, databaseId,
        taskRegistryFactory, userLogRegistryFactory,
        graphName, originGraphName,
        configuration,
        RWR_CONFIG_PROVIDER, RWR_PROVIDER
    );
}

/**
 * Samples the graph {@code originGraphName} into a new graph {@code graphNameAsString}.
 * <p>
 * All work is delegated to {@code sampleRandomWalk}; this entry point only selects
 * {@code CNARW_CONFIG_PROVIDER} and {@code CNARW_PROVIDER} as the configuration and sampler factories.
 */
@Override
public RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
    User user,
    DatabaseId databaseId,
    TaskRegistryFactory taskRegistryFactory,
    UserLogRegistryFactory userLogRegistryFactory,
    String graphNameAsString,
    String originGraphName,
    Map<String, Object> configuration
) {
    return sampleRandomWalk(
        user, databaseId,
        taskRegistryFactory, userLogRegistryFactory,
        graphNameAsString, originGraphName,
        configuration,
        CNARW_CONFIG_PROVIDER, CNARW_PROVIDER
    );
}

/**
 * Parses {@code rawConfiguration} via the configuration service and forwards the typed result,
 * together with the user, database and graph name, to the CNARW estimation application.
 *
 * @return whatever {@code estimateCommonNeighbourAwareRandomWalkApplication.estimate} produces
 */
@Override
public MemoryEstimateResult estimateCommonNeighbourAwareRandomWalk(
    User user,
    DatabaseId databaseId,
    String graphName,
    Map<String, Object> rawConfiguration
) {
    return estimateCommonNeighbourAwareRandomWalkApplication.estimate(
        user,
        databaseId,
        graphName,
        configurationService.parseCommonNeighbourAwareRandomWalkConfig(rawConfiguration)
    );
}

/**
 * Shared implementation behind the random-walk sampling entry points.
 * <p>
 * Steps, in order: check the target graph name via {@code ensureGraphNameValidAndUnknown},
 * parse the origin graph name, fetch the origin graph store and its projection config from the
 * catalog, then hand everything to {@code graphSamplingApplication.sample}.
 *
 * @param samplerConfigProvider    builds the sampler configuration from the wrapped raw configuration
 * @param samplerAlgorithmProvider builds the node sampler from that configuration
 */
private RandomWalkSamplingResult sampleRandomWalk(
    User user,
    DatabaseId databaseId,
    TaskRegistryFactory taskRegistryFactory,
    UserLogRegistryFactory userLogRegistryFactory,
    String graphNameAsString,
    String originGraphNameAsString,
    Map<String, Object> configuration,
    Function<CypherMapWrapper, RandomWalkWithRestartsConfig> samplerConfigProvider,
    Function<RandomWalkWithRestartsConfig, RandomWalkBasedNodesSampler> samplerAlgorithmProvider
) {
    // The target name is checked before the origin graph is looked up.
    var targetGraphName = ensureGraphNameValidAndUnknown(user, databaseId, graphNameAsString);
    var sourceGraphName = GraphName.parse(originGraphNameAsString);

    var catalogRequest = CatalogRequest.of(user, databaseId);
    var source = graphStoreCatalogService.get(catalogRequest, sourceGraphName);

    return graphSamplingApplication.sample(
        user,
        taskRegistryFactory,
        userLogRegistryFactory,
        source.graphStore(),
        source.config(),
        sourceGraphName,
        targetGraphName,
        configuration,
        samplerConfigProvider,
        samplerAlgorithmProvider
    );
}

private GraphName ensureGraphNameValidAndUnknown(User user, DatabaseId databaseId, String graphNameAsString) {
var graphName = graphNameValidationService.validateStrictly(graphNameAsString);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.graphstorecatalog;

import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.User;
import org.neo4j.gds.core.utils.mem.MemoryTreeWithDimensions;
import org.neo4j.gds.executor.GraphStoreFromCatalogLoader;
import org.neo4j.gds.graphsampling.config.CommonNeighbourAwareRandomWalkConfig;
import org.neo4j.gds.graphsampling.samplers.rw.cnarw.CommonNeighbourAwareRandomWalk;
import org.neo4j.gds.results.MemoryEstimateResult;

/**
 * Produces a memory estimate for running the common-neighbour-aware random walk sampler
 * against a graph that is looked up through a {@link GraphStoreFromCatalogLoader}.
 */
public class EstimateCommonNeighbourAwareRandomWalkApplication {
    /**
     * Estimates memory for the given sampler configuration.
     *
     * @param user          supplies the username and admin flag passed to the loader
     * @param databaseId    database passed to the loader
     * @param graphName     name of the graph passed to the loader
     * @param configuration sampler configuration; also supplies the concurrency used for the estimate
     * @return the memory tree paired with the graph dimensions it was computed from
     */
    MemoryEstimateResult estimate(
        User user,
        DatabaseId databaseId,
        String graphName,
        CommonNeighbourAwareRandomWalkConfig configuration
    ) {
        var loader = new GraphStoreFromCatalogLoader(
            graphName,
            configuration,
            user.getUsername(),
            databaseId,
            user.isAdmin()
        );

        var memoryEstimation = CommonNeighbourAwareRandomWalk.memoryEstimation(configuration);

        // Resolve the dimensions once: the estimate and the reported dimensions then come from
        // the same object, and the loader is not asked to compute them a second time.
        var graphDimensions = loader.graphDimensions();

        var memoryTree = memoryEstimation.estimate(graphDimensions, configuration.concurrency());

        return new MemoryEstimateResult(new MemoryTreeWithDimensions(memoryTree, graphDimensions));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.graphstorecatalog;

import org.neo4j.gds.api.GraphName;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.User;
import org.neo4j.gds.config.GraphProjectConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.loading.GraphStoreCatalogService;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
import org.neo4j.gds.graphsampling.GraphSampleConstructor;
import org.neo4j.gds.graphsampling.RandomWalkBasedNodesSampler;
import org.neo4j.gds.graphsampling.config.RandomWalkWithRestartsConfig;
import org.neo4j.gds.logging.Log;

import java.util.Map;
import java.util.function.Function;

/**
 * Runs a random-walk based sampler over a graph store and stores the sampled graph store
 * in the graph store catalog.
 */
public final class GraphSamplingApplication {
    private final Log log;
    private final GraphStoreCatalogService graphStoreCatalogService;

    public GraphSamplingApplication(Log log, GraphStoreCatalogService graphStoreCatalogService) {
        this.log = log;
        this.graphStoreCatalogService = graphStoreCatalogService;
    }

    /**
     * Samples {@code graphStore} and stores the outcome in the catalog under {@code graphName}.
     *
     * @param graphStore               the graph store to sample from
     * @param graphProjectConfig       projection config of the origin graph, passed on to the catalog configuration
     * @param originGraphName          name of the graph being sampled
     * @param graphName                name under which the sampled graph is stored
     * @param configuration            raw sampler configuration
     * @param samplerConfigProvider    builds the sampler configuration from the wrapped raw configuration
     * @param samplerAlgorithmProvider builds the node sampler from that configuration
     * @return graph names, node and relationship counts of the sample, the sampler's start node
     *         count, and the duration reported by the progress timer
     */
    RandomWalkSamplingResult sample(
        User user,
        TaskRegistryFactory taskRegistryFactory,
        UserLogRegistryFactory userLogRegistryFactory,
        GraphStore graphStore,
        GraphProjectConfig graphProjectConfig,
        GraphName originGraphName,
        GraphName graphName,
        Map<String, Object> configuration,
        Function<CypherMapWrapper, RandomWalkWithRestartsConfig> samplerConfigProvider,
        Function<RandomWalkWithRestartsConfig, RandomWalkBasedNodesSampler> samplerAlgorithmProvider
    ) {
        try (var timer = ProgressTimer.start()) {
            var wrappedConfiguration = CypherMapWrapper.create(configuration);
            var samplingConfig = samplerConfigProvider.apply(wrappedConfiguration);
            var sampler = samplerAlgorithmProvider.apply(samplingConfig);

            var tracker = createProgressTracker(
                graphStore,
                sampler,
                samplingConfig,
                taskRegistryFactory,
                userLogRegistryFactory
            );

            var sampledGraphStore = new GraphSampleConstructor(
                samplingConfig,
                graphStore,
                sampler,
                tracker
            ).compute();

            // The catalog configuration is only built once sampling has completed.
            var catalogConfiguration = RandomWalkWithRestartsConfiguration.of(
                user.getUsername(),
                graphName.getValue(),
                originGraphName.getValue(),
                graphProjectConfig,
                wrappedConfiguration
            );
            graphStoreCatalogService.set(catalogConfiguration, sampledGraphStore);

            var elapsedMillis = timer.stop().getDuration();

            return new RandomWalkSamplingResult(
                graphName.getValue(),
                originGraphName.getValue(),
                sampledGraphStore.nodeCount(),
                sampledGraphStore.relationshipCount(),
                sampler.startNodesCount(),
                elapsedMillis
            );
        }
    }

    /**
     * Builds the progress tracker for one sampling run from the sampler's progress task,
     * the Neo4j log, and the concurrency and job id of the sampler configuration.
     */
    private TaskProgressTracker createProgressTracker(
        GraphStore graphStore,
        RandomWalkBasedNodesSampler sampler,
        RandomWalkWithRestartsConfig samplingConfig,
        TaskRegistryFactory taskRegistryFactory,
        UserLogRegistryFactory userLogRegistryFactory
    ) {
        return new TaskProgressTracker(
            GraphSampleConstructor.progressTask(graphStore, sampler),
            (org.neo4j.logging.Log) log.getNeo4jLog(),
            samplingConfig.concurrency(),
            samplingConfig.jobId(),
            taskRegistryFactory,
            userLogRegistryFactory
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,31 @@ WriteRelationshipResult writeRelationships(
String relationshipProperty,
Map<String, Object> configuration
);

/**
* Samples a graph using random walk with restarts.
*
* @param user the user on whose behalf the operation runs
* @param databaseId the database the graphs belong to
* @param taskRegistryFactory task registry factory handed through to the sampling run
* @param userLogRegistryFactory user log registry factory handed through to the sampling run
* @param graphName name for the sampled graph
* @param originGraphName name of the graph to sample from
* @param configuration raw sampler configuration
* @return the result describing the sampled graph
*/
RandomWalkSamplingResult sampleRandomWalkWithRestarts(
User user,
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
String graphName,
String originGraphName,
Map<String, Object> configuration
);

/**
* Samples a graph using the common-neighbour-aware random walk.
*
* @param user the user on whose behalf the operation runs
* @param databaseId the database the graphs belong to
* @param taskRegistryFactory task registry factory handed through to the sampling run
* @param userLogRegistryFactory user log registry factory handed through to the sampling run
* @param graphName name for the sampled graph
* @param originGraphName name of the graph to sample from
* @param configuration raw sampler configuration
* @return the result describing the sampled graph
*/
RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
User user,
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
String graphName,
String originGraphName,
Map<String, Object> configuration
);

/**
* Estimates memory for common-neighbour-aware random walk sampling of a graph.
*
* @param user the user on whose behalf the estimate is made
* @param databaseId the database the graph belongs to
* @param graphName name of the graph the estimate is for
* @param configuration raw sampler configuration
* @return the memory estimate
*/
MemoryEstimateResult estimateCommonNeighbourAwareRandomWalk(
User user,
DatabaseId databaseId,
String graphName,
Map<String, Object> configuration
);
}
Loading

0 comments on commit 23ddb43

Please sign in to comment.