Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin authored Jan 30, 2024
2 parents cd68728 + b085224 commit c3a903d
Showing 4 changed files with 80 additions and 61 deletions.
5 changes: 1 addition & 4 deletions examples/multigpu/graphbolt/node_classification.py
@@ -139,10 +139,7 @@ def create_dataloader(
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device)

# Until https://github.com/dmlc/dgl/issues/7008, overlap should be False.
dataloader = gb.DataLoader(
datapipe, args.num_workers, overlap_feature_fetch=False
)
dataloader = gb.DataLoader(datapipe, args.num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
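
This hunk drops the workaround noted in the removed comment, so the call site no longer disables overlapped feature fetching. A minimal sketch of the simplified helper; the function name and parameters are illustrative, and `datapipe` is assumed to be built upstream as in this example script:

import dgl.graphbolt as gb

def make_dataloader(datapipe, num_workers, device, storage_device="cpu"):
    # Move the pipeline output to the target device when data lives on CPU.
    if storage_device == "cpu":
        datapipe = datapipe.copy_to(device)
    # Overlapped feature fetching no longer has to be disabled explicitly.
    return gb.DataLoader(datapipe, num_workers)
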
65 changes: 40 additions & 25 deletions python/dgl/distributed/dist_graph.py
@@ -8,7 +8,7 @@

import numpy as np

from .. import backend as F, heterograph_index
from .. import backend as F, graphbolt as gb, heterograph_index
from .._ffi.ndarray import empty_shared_mem
from ..base import ALL, DGLError, EID, ETYPE, is_all, NID
from ..convert import graph as dgl_graph, heterograph as dgl_heterograph
@@ -88,7 +88,9 @@ def __setstate__(self, state):
self._graph_name = state


def _copy_graph_to_shared_mem(g, graph_name, graph_format):
def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt:
return g.copy_to_shared_memory(graph_name)
new_g = g.shared_memory(graph_name, formats=graph_format)
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
@@ -298,6 +300,30 @@ def __repr__(self):
return repr(reprs)


def _format_partition(graph, graph_format):
"""Format the partition to the specified format."""
if isinstance(graph, gb.FusedCSCSamplingGraph):
return graph
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in graph.ndata:
graph.ndata[k] = F.astype(graph.ndata[k], dtype)
if k in graph.edata:
graph.edata[k] = F.astype(graph.edata[k], dtype)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
)
graph = graph.formats(graph_format)
graph.create_formats_()
print(f"Finished creating specified graph formats: {graph_format}")
return graph
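
For a DGLGraph partition, the helper added above amounts to a dtype cast on reserved fields plus eager sparse-format creation; a GraphBolt FusedCSCSamplingGraph passes through untouched. A self-contained toy illustration of that DGLGraph path (the random graph and the int32 cast are only for demonstration):

import dgl
import torch

g = dgl.rand_graph(10, 20)
g.ndata[dgl.NID] = torch.arange(g.num_nodes(), dtype=torch.int32)
g.ndata[dgl.NID] = g.ndata[dgl.NID].long()  # mimic the RESERVED_FIELD_DTYPE cast (IDs are int64)
g = g.formats(("csc", "coo"))               # restrict to the requested formats
g.create_formats_()                         # materialize them eagerly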


class DistGraphServer(KVServer):
"""The DistGraph server.
Expand Down Expand Up @@ -330,6 +356,8 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
use_graphbolt : bool
Whether to load GraphBolt partition. Default: False.
"""

def __init__(
@@ -341,6 +369,7 @@ def __init__(
part_config,
disable_shared_mem=False,
graph_format=("csc", "coo"),
use_graphbolt=False,
):
super(DistGraphServer, self).__init__(
server_id=server_id,
@@ -350,6 +379,7 @@ def __init__(
)
self.ip_config = ip_config
self.num_servers = num_servers
self.use_graphbolt = use_graphbolt
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards.
@@ -367,32 +397,17 @@ def __init__(
graph_name,
ntypes,
etypes,
) = load_partition(part_config, self.part_id, load_feats=False)
print("load " + graph_name)
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype
)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype
)
# Create the graph formats specified the users.
print(
"Start to create specified graph formats which may take "
"non-trivial time."
) = load_partition(
part_config,
self.part_id,
load_feats=False,
use_graphbolt=use_graphbolt,
)
self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_()
print("Finished creating specified graph formats.")
print("load " + graph_name)
self.client_g = _format_partition(self.client_g, graph_format)
if not disable_shared_mem:
self.client_g = _copy_graph_to_shared_mem(
self.client_g, graph_name, graph_format
self.client_g, graph_name, graph_format, use_graphbolt
)

if not disable_shared_mem:
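
Taken together, the server now loads a GraphBolt partition and shares it through GraphBolt's own shared-memory path. A rough sketch of that flow; the partition-config path, part id, and shared-memory name are hypothetical, and it assumes the first element of load_partition's return tuple is the partition graph:

from dgl.distributed import load_partition

part_config = "data/ogbn-products.json"  # hypothetical path to a partition config
parts = load_partition(part_config, 0, load_feats=False, use_graphbolt=True)
client_g = parts[0]  # a gb.FusedCSCSamplingGraph rather than a DGLGraph

# Sharing with clients then uses GraphBolt's shared-memory mechanism, as
# _copy_graph_to_shared_mem now does when use_graphbolt=True.
shared_g = client_g.copy_to_shared_memory("part-0")  # hypothetical name
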
42 changes: 16 additions & 26 deletions python/dgl/graphbolt/dataloader.py
@@ -16,17 +16,18 @@

__all__ = [
"DataLoader",
"Awaiter",
"Bufferer",
]


def _find_and_wrap_parent(
datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
):
def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
datapipe_graph,
target_datapipe,
)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
for datapipe in datapipes:
datapipe_id = id(datapipe)
for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
@@ -36,6 +37,7 @@ def __init__(self, datapipe):
parent_datapipe,
wrapper(parent_datapipe, **kwargs),
)
return datapipe_graph
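
The helper now builds the adjacency list itself and returns the updated graph, so callers chain `datapipe_graph = _find_and_wrap_parent(...)`. The rewriting pattern it relies on can be reproduced standalone with torchdata; a toy sketch in which a plain IterableWrapper and a shuffle wrapper stand in for the GraphBolt stages:

import torchdata.datapipes as dp
import torchdata.dataloader2.graph as dp_utils

source = dp.iter.IterableWrapper(range(8))
mapped = source.map(lambda x: x * 2)

graph = dp_utils.traverse_dps(mapped)
targets = dp_utils.find_dps(graph, dp.iter.IterableWrapper)
for node in targets:
    # Replace the found stage with a wrapped version of itself, mirroring how
    # the DataLoader wraps FeatureFetcher parents below.
    graph = dp_utils.replace_dp(graph, node, node.shuffle())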


class EndMarker(dp.iter.IterDataPipe):
@@ -45,8 +47,7 @@ def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
for data in self.datapipe:
yield data
yield from self.datapipe


class Bufferer(dp.iter.IterDataPipe):
@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider increasing passing a high
value. Default is 2.
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""

def __init__(self, datapipe, buffer_size=2):
def __init__(self, datapipe, buffer_size=1):
self.datapipe = datapipe
if buffer_size <= 0:
raise ValueError(
@@ -180,7 +181,6 @@ def __init__(

datapipe = EndMarker(datapipe)
datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)

# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
@@ -198,9 +198,8 @@
)

# (2) Cut datapipe at FeatureFetcher and wrap.
_find_and_wrap_parent(
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
FeatureFetcher,
MultiprocessingWrapper,
num_workers=num_workers,
@@ -221,25 +220,16 @@
)
for feature_fetcher in feature_fetchers:
feature_fetcher.stream = _get_uva_stream()
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Bufferer,
buffer_size=2,
)
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Awaiter,
)
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
feature_fetcher,
Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent(
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
CopyTo,
dp.iter.Prefetcher,
buffer_size=2,
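
Step (4) wraps everything up to CopyTo in torchdata's Prefetcher so the stages before the device copy run ahead in a background thread. A standalone toy of the same wrapping, with plain integers standing in for minibatches:

import torchdata.datapipes as dp

source = dp.iter.IterableWrapper(range(100)).map(lambda x: x + 1)
# The wrapped portion of the pipeline runs ahead of the consumer with a small buffer.
prefetched = dp.iter.Prefetcher(source, buffer_size=2)
print(sum(prefetched))
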
29 changes: 23 additions & 6 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -7,6 +7,8 @@
import pytest
import torch

import torchdata.dataloader2.graph as dp_utils

from . import gb_test_utils


@@ -46,7 +48,8 @@ def test_DataLoader():
reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch):
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
N = 40
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
@@ -70,13 +73,27 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
datapipe = dgl.graphbolt.FeatureFetcher(
datapipe,
feature_store,
["a", "b"],
)
if enable_feature_fetch:
datapipe = dgl.graphbolt.FeatureFetcher(
datapipe,
feature_store,
["a", "b"],
)

dataloader = dgl.graphbolt.DataLoader(
datapipe, overlap_feature_fetch=overlap_feature_fetch
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
datapipe = dataloader.dataset
datapipe_graph = dp_utils.traverse_dps(datapipe)
awaiters = dp_utils.find_dps(
datapipe_graph,
dgl.graphbolt.Awaiter,
)
assert len(awaiters) == bufferer_awaiter_cnt
bufferers = dp_utils.find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,
)
assert len(bufferers) == bufferer_awaiter_cnt
assert len(list(dataloader)) == N // B
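
The new assertions inspect the assembled datapipe graph directly. The same check can be factored into a small helper; the helper name is illustrative, and it assumes a DataLoader built as in the test above plus the now-exported Bufferer/Awaiter classes:

import torchdata.dataloader2.graph as dp_utils
import dgl

def count_overlap_stages(dataloader):
    # Walk the DataLoader's datapipe graph and count the overlap stages.
    graph = dp_utils.traverse_dps(dataloader.dataset)
    bufferers = dp_utils.find_dps(graph, dgl.graphbolt.Bufferer)
    awaiters = dp_utils.find_dps(graph, dgl.graphbolt.Awaiter)
    return len(bufferers), len(awaiters)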
