From 19ac19c46673e415aab845f51619807cb2226b08 Mon Sep 17 00:00:00 2001 From: RhettYing Date: Tue, 30 Jan 2024 02:29:26 +0000 Subject: [PATCH] [DistGB] enable DistGraphServer to load graphbolt partitions --- python/dgl/distributed/dist_graph.py | 65 +++++++++++++++++----------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/python/dgl/distributed/dist_graph.py b/python/dgl/distributed/dist_graph.py index 192293a80676..4defa3223937 100644 --- a/python/dgl/distributed/dist_graph.py +++ b/python/dgl/distributed/dist_graph.py @@ -8,7 +8,7 @@ import numpy as np -from .. import backend as F, heterograph_index +from .. import backend as F, graphbolt as gb, heterograph_index from .._ffi.ndarray import empty_shared_mem from ..base import ALL, DGLError, EID, ETYPE, is_all, NID from ..convert import graph as dgl_graph, heterograph as dgl_heterograph @@ -88,7 +88,9 @@ def __setstate__(self, state): self._graph_name = state -def _copy_graph_to_shared_mem(g, graph_name, graph_format): +def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt): + if use_graphbolt: + return g.copy_to_shared_memory(graph_name) new_g = g.shared_memory(graph_name, formats=graph_format) # We should share the node/edge data to the client explicitly instead of putting them # in the KVStore because some of the node/edge data may be duplicated. @@ -298,6 +300,30 @@ def __repr__(self): return repr(reprs) +def _format_partition(graph, graph_format): + """Format the partition to the specified format.""" + if isinstance(graph, gb.FusedCSCSamplingGraph): + return graph + # formatting dtype + # TODO(Rui) Forcibly casting dtypes is not a perfect solution. + # We'd better store all dtypes when mapping to shared memory + # and map back with original dtypes. 
+ for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in graph.ndata: + graph.ndata[k] = F.astype(graph.ndata[k], dtype) + if k in graph.edata: + graph.edata[k] = F.astype(graph.edata[k], dtype) + # Create the graph formats specified by the users. + print( + "Start to create specified graph formats which may take " + "non-trivial time." + ) + graph = graph.formats(graph_format) + graph.create_formats_() + print(f"Finished creating specified graph formats: {graph_format}") + return graph + + class DistGraphServer(KVServer): """The DistGraph server. @@ -330,6 +356,8 @@ class DistGraphServer(KVServer): Disable shared memory. graph_format : str or list of str The graph formats. + use_graphbolt : bool + Whether to load GraphBolt partition. Default: False. """ def __init__( @@ -341,6 +369,7 @@ def __init__( part_config, disable_shared_mem=False, graph_format=("csc", "coo"), + use_graphbolt=False, ): super(DistGraphServer, self).__init__( server_id=server_id, @@ -350,6 +379,7 @@ def __init__( ) self.ip_config = ip_config self.num_servers = num_servers + self.use_graphbolt = use_graphbolt # Load graph partition data. if self.is_backup_server(): # The backup server doesn't load the graph partition. It'll initialized afterwards. @@ -367,32 +397,17 @@ def __init__( graph_name, ntypes, etypes, - ) = load_partition(part_config, self.part_id, load_feats=False) - print("load " + graph_name) - # formatting dtype - # TODO(Rui) Formatting forcely is not a perfect solution. - # We'd better store all dtypes when mapping to shared memory - # and map back with original dtypes. - for k, dtype in RESERVED_FIELD_DTYPE.items(): - if k in self.client_g.ndata: - self.client_g.ndata[k] = F.astype( - self.client_g.ndata[k], dtype - ) - if k in self.client_g.edata: - self.client_g.edata[k] = F.astype( - self.client_g.edata[k], dtype - ) - # Create the graph formats specified the users. - print( - "Start to create specified graph formats which may take " - "non-trivial time." 
+ ) = load_partition( + part_config, + self.part_id, + load_feats=False, + use_graphbolt=use_graphbolt, ) - self.client_g = self.client_g.formats(graph_format) - self.client_g.create_formats_() - print("Finished creating specified graph formats.") + print("load " + graph_name) + self.client_g = _format_partition(self.client_g, graph_format) if not disable_shared_mem: self.client_g = _copy_graph_to_shared_mem( - self.client_g, graph_name, graph_format + self.client_g, graph_name, graph_format, use_graphbolt ) if not disable_shared_mem: