Skip to content

Commit

Permalink
[DistGB] refine convertion from dgl to graphbolt (#7007)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jan 25, 2024
1 parent 0a2f40f commit 0bfe34d
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 37 deletions.
2 changes: 2 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ pipeline {
steps {
unit_test_linux('tensorflow', 'cpu')
}
// Tensorflow is deprecated.
when { expression { false } }
}
}
post {
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .kvstore import KVClient, KVServer
from .nn import *
from .partition import (
convert_dgl_partition_to_csc_sampling_graph,
dgl_partition_to_graphbolt,
load_partition,
load_partition_book,
load_partition_feats,
Expand Down
120 changes: 100 additions & 20 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Functions for partitions. """

import copy
import json
import logging
import os
Expand All @@ -9,8 +10,8 @@

import torch

from .. import backend as F
from ..base import DGLError, EID, ETYPE, NID, NTYPE
from .. import backend as F, graphbolt as gb
from ..base import dgl_warning, DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous
from ..data.utils import load_graphs, load_tensors, save_graphs, save_tensors
from ..partition import (
Expand Down Expand Up @@ -1223,12 +1224,18 @@ def get_homogeneous(g, balance_ntypes):
return orig_nids, orig_eids


def convert_dgl_partition_to_csc_sampling_graph(part_config):
def dgl_partition_to_graphbolt(
part_config,
*,
store_eids=False,
store_inner_node=False,
store_inner_edge=False,
):
"""Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.
This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is
dedicated for sampling in `GraphBolt`. New graphs will be stored alongside
original graph as `fused_csc_sampling_graph.tar`.
original graph as `fused_csc_sampling_graph.pt`.
In the near future, partitions are supposed to be saved as
`FusedCSCSamplingGraph` directly. At that time, this API should be deprecated.
Expand All @@ -1237,42 +1244,106 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
----------
part_config : str
The partition configuration JSON file.
store_eids : bool, optional
Whether to store edge IDs in the new graph. Default: False.
store_inner_node : bool, optional
Whether to store inner node mask in the new graph. Default: False.
store_inner_edge : bool, optional
Whether to store inner edge mask in the new graph. Default: False.
"""

# As only this function requires GraphBolt for now, let's import here.
from .. import graphbolt

debug_mode = "DGL_DIST_DEBUG" in os.environ
if debug_mode:
dgl_warning(
"Running in debug mode which means all attributes of DGL partitions"
" will be saved to the new format."
)
part_meta = _load_part_config(part_config)
new_part_meta = copy.deepcopy(part_meta)
num_parts = part_meta["num_parts"]

# Utility functions.
def is_homogeneous(ntypes, etypes):
return len(ntypes) == 1 and len(etypes) == 1

def init_type_per_edge(graph, gpb):
etype_ids = gpb.map_to_per_etype(graph.edata[EID])[0]
return etype_ids

# [Rui] DGL partitions are always saved as homogeneous graphs even though
# the original graph is heterogeneous. But heterogeneous information like
# node/edge types are saved as node/edge data alongside with partitions.
# What needs more attention is that due to the existence of HALO nodes in
# each partition, the local node IDs are not sorted according to the node
# types. So we fail to assign ``node_type_offset`` as required by GraphBolt.
# But this is not a problem since such information is not used in sampling.
# We can simply pass None to it.

# Iterate over partitions.
for part_id in range(num_parts):
graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False
)
_, _, ntypes, etypes = load_partition_book(part_config, part_id)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
_etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes)
}
is_homo = is_homogeneous(ntypes, etypes)
node_type_to_id = (
None
if is_homo
else {ntype: ntid for ntid, ntype in enumerate(ntypes)}
)
edge_type_to_id = (
None
if is_homo
else {
gb.etype_tuple_to_str(etype): etid
for etype, etid in etypes.items()
}
)
# Obtain CSC indtpr and indices.
indptr, indices, _ = graph.adj().csc()
# Initalize type per edge.
type_per_edge = init_type_per_edge(graph, gpb)
type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])
# Sanity check.
assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.fused_csc_sampling_graph(
indptr, indices, edge_ids = graph.adj_tensors("csc")

# Save node attributes. Detailed attributes are shown below.
# DGL_GB\Attributes dgl.NID("_ID") dgl.NTYPE("_TYPE") "inner_node" "part_id"
# DGL_Homograph ✅ 🚫 ✅ ✅
# GB_Homograph ✅ 🚫 optional 🚫
# DGL_Heterograph ✅ ✅ ✅ ✅
# GB_Heterograph ✅ 🚫 optional 🚫
required_node_attrs = [NID]
if store_inner_node:
required_node_attrs.append("inner_node")
if debug_mode:
required_node_attrs = list(graph.ndata.keys())
node_attributes = {
attr: graph.ndata[attr] for attr in required_node_attrs
}

# Save edge attributes. Detailed attributes are shown below.
# DGL_GB\Attributes dgl.EID("_ID") dgl.ETYPE("_TYPE") "inner_edge"
# DGL_Homograph ✅ 🚫 ✅
# GB_Homograph optional 🚫 optional
# DGL_Heterograph ✅ ✅ ✅
# GB_Heterograph optional ✅ optional
type_per_edge = None
if not is_homo:
type_per_edge = init_type_per_edge(graph, gpb)[edge_ids]
type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])
required_edge_attrs = []
if store_eids:
required_edge_attrs.append(EID)
if store_inner_edge:
required_edge_attrs.append("inner_edge")
if debug_mode:
required_edge_attrs = list(graph.edata.keys())
edge_attributes = {
attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs
}

csc_graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=None,
type_per_edge=type_per_edge,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
Expand All @@ -1284,3 +1355,12 @@ def init_type_per_edge(graph, gpb):
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt"
)
torch.save(csc_graph, csc_graph_path)

# Update graph path.
new_part_meta[f"part-{part_id}"]["gb_part_graph"] = os.path.relpath(
csc_graph_path, os.path.dirname(part_config)
)

# Update partition config.
_dump_part_config(part_config, new_part_meta)
print(f"Converted partitions to GraphBolt format into {part_config}")
126 changes: 111 additions & 15 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch as th
from dgl import function as fn
from dgl.distributed import (
convert_dgl_partition_to_csc_sampling_graph,
dgl_partition_to_graphbolt,
load_partition,
load_partition_book,
load_partition_feats,
Expand Down Expand Up @@ -680,17 +680,34 @@ def test_UnknownPartitionBook():

@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
def test_convert_dgl_partition_to_csc_sampling_graph_homo(
part_method, num_parts
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_dgl_partition_to_graphbolt_homo(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_graph(1000)
graph_name = "test"
partition_graph(
g, graph_name, num_parts, test_dir, part_method=part_method
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
convert_dgl_partition_to_csc_sampling_graph(part_config)
dgl_partition_to_graphbolt(
part_config,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
Expand All @@ -700,30 +717,69 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert all(new_g.type_per_edge == 0)
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"][orig_eids],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
assert new_g.type_per_edge is None
assert new_g.node_type_to_id is None
assert new_g.edge_type_to_id is None


@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
part_method, num_parts
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_dgl_partition_to_graphbolt_hetero(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_hetero()
graph_name = "test"
partition_graph(
g, graph_name, num_parts, test_dir, part_method=part_method
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
convert_dgl_partition_to_csc_sampling_graph(part_config)
dgl_partition_to_graphbolt(
part_config,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
Expand All @@ -733,15 +789,55 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if debug_mode:
assert th.equal(
orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE]
)
else:
assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
if debug_mode:
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids],
new_g.edge_attributes[dgl.ETYPE],
)
else:
assert dgl.ETYPE not in new_g.edge_attributes
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge
)

for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)


def test_not_sorted_node_edge_map():
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/cugraph_unit_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ export TF_FORCE_GPU_ALLOW_GROWTH=true

export CUDA_VISIBLE_DEVICES=0

python3 -m pip install pytest psutil pyyaml pydantic pandas rdflib ogb || fail "pip install"
python3 -m pip install pytest psutil pyyaml pydantic pandas rdflib ogb torchdata || fail "pip install"

python3 -m pytest -v --junitxml=pytest_cugraph.xml --durations=20 tests/cugraph || fail "cugraph"

0 comments on commit 0bfe34d

Please sign in to comment.