Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DistGB] enable DistGraphServer to load graphbolt partitions #7042

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from .. import backend as F, heterograph_index
from .. import backend as F, graphbolt as gb, heterograph_index
from .._ffi.ndarray import empty_shared_mem
from ..base import ALL, DGLError, EID, ETYPE, is_all, NID
from ..convert import graph as dgl_graph, heterograph as dgl_heterograph
Expand Down Expand Up @@ -88,7 +88,9 @@ def __setstate__(self, state):
self._graph_name = state


def _copy_graph_to_shared_mem(g, graph_name, graph_format):
def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt:
return g.copy_to_shared_memory(graph_name)
new_g = g.shared_memory(graph_name, formats=graph_format)
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
Expand Down Expand Up @@ -298,6 +300,30 @@ def __repr__(self):
return repr(reprs)


def _format_partition(graph, graph_format):
"""Format the partition to the specified format."""
if isinstance(graph, gb.FusedCSCSamplingGraph):
return graph
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in graph.ndata:
graph.ndata[k] = F.astype(graph.ndata[k], dtype)
if k in graph.edata:
graph.edata[k] = F.astype(graph.edata[k], dtype)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
)
graph = graph.formats(graph_format)
graph.create_formats_()
print(f"Finished creating specified graph formats: {graph_format}")
return graph


class DistGraphServer(KVServer):
"""The DistGraph server.

Expand Down Expand Up @@ -330,6 +356,8 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
use_graphbolt : bool
Whether to load GraphBolt partition. Default: False.
"""

def __init__(
Expand All @@ -341,6 +369,7 @@ def __init__(
part_config,
disable_shared_mem=False,
graph_format=("csc", "coo"),
use_graphbolt=False,
):
super(DistGraphServer, self).__init__(
server_id=server_id,
Expand All @@ -350,6 +379,7 @@ def __init__(
)
self.ip_config = ip_config
self.num_servers = num_servers
self.use_graphbolt = use_graphbolt
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards.
Expand All @@ -367,32 +397,17 @@ def __init__(
graph_name,
ntypes,
etypes,
) = load_partition(part_config, self.part_id, load_feats=False)
print("load " + graph_name)
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype
)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype
)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
) = load_partition(
part_config,
self.part_id,
load_feats=False,
use_graphbolt=use_graphbolt,
)
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
print("Finished creating specified graph formats.")
print("load " + graph_name)
self.client_g = _format_partition(self.client_g, graph_format)
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(
self.client_g, graph_name, graph_format
self.client_g, graph_name, graph_format, use_graphbolt
)

if not disable_shared_mem:
Expand Down
Loading