diff --git a/python/dgl/distributed/dist_graph.py b/python/dgl/distributed/dist_graph.py
index 4defa3223937..5bf76498ec97 100644
--- a/python/dgl/distributed/dist_graph.py
+++ b/python/dgl/distributed/dist_graph.py
@@ -60,18 +60,21 @@ class InitGraphRequest(rpc.Request):
     with shared memory.
     """
 
-    def __init__(self, graph_name):
+    def __init__(self, graph_name, use_graphbolt):
         self._graph_name = graph_name
+        self._use_graphbolt = use_graphbolt
 
     def __getstate__(self):
-        return self._graph_name
+        return self._graph_name, self._use_graphbolt
 
     def __setstate__(self, state):
-        self._graph_name = state
+        self._graph_name, self._use_graphbolt = state
 
     def process_request(self, server_state):
         if server_state.graph is None:
-            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
+            server_state.graph = _get_graph_from_shared_mem(
+                self._graph_name, self._use_graphbolt
+            )
         return InitGraphResponse(self._graph_name)
 
 
@@ -153,13 +156,15 @@ def _exist_shared_mem_array(graph_name, name):
     return exist_shared_mem_array(_get_edata_path(graph_name, name))
 
 
-def _get_graph_from_shared_mem(graph_name):
+def _get_graph_from_shared_mem(graph_name, use_graphbolt):
     """Get the graph from the DistGraph server.
 
     The DistGraph server puts the graph structure of the local partition in the shared memory.
     The client can access the graph structure and some metadata on nodes and edges directly
     through shared memory to reduce the overhead of data access.
     """
+    if use_graphbolt:
+        return gb.load_from_shared_memory(graph_name)
     g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
         graph_name
     )
@@ -524,6 +529,8 @@ class DistGraph:
     part_config : str, optional
         The path of partition configuration file generated by
        :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
+    use_graphbolt : bool, optional
+        Whether to load GraphBolt partition. Default: False.
 
     Examples
     --------
@@ -557,9 +564,15 @@ class DistGraph:
     manually setting up servers and trainers. The setup is not fully tested yet.
     """
 
-    def __init__(self, graph_name, gpb=None, part_config=None):
+    def __init__(
+        self, graph_name, gpb=None, part_config=None, use_graphbolt=False
+    ):
         self.graph_name = graph_name
+        self._use_graphbolt = use_graphbolt
         if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
+            assert (
+                use_graphbolt is False
+            ), "GraphBolt is not supported in standalone mode."
             assert (
                 part_config is not None
             ), "When running in the standalone model, the partition config file is required"
@@ -600,7 +613,9 @@ def __init__(self, graph_name, gpb=None, part_config=None):
             self._init(gpb)
             # Tell the backup servers to load the graph structure from shared memory.
             for server_id in range(self._client.num_servers):
-                rpc.send_request(server_id, InitGraphRequest(graph_name))
+                rpc.send_request(
+                    server_id, InitGraphRequest(graph_name, use_graphbolt)
+                )
             for server_id in range(self._client.num_servers):
                 rpc.recv_response()
             self._client.barrier()
@@ -625,7 +640,9 @@ def _init(self, gpb):
         assert (
             self._client is not None
         ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
-        self._g = _get_graph_from_shared_mem(self.graph_name)
+        self._g = _get_graph_from_shared_mem(
+            self.graph_name, self._use_graphbolt
+        )
         self._gpb = get_shared_mem_partition_book(self.graph_name)
         if self._gpb is None:
             self._gpb = gpb
@@ -682,10 +699,10 @@ def _init_edata_store(self):
             self._edata_store[etype] = data
 
     def __getstate__(self):
-        return self.graph_name, self._gpb
+        return self.graph_name, self._gpb, self._use_graphbolt
 
     def __setstate__(self, state):
-        self.graph_name, gpb = state
+        self.graph_name, gpb, self._use_graphbolt = state
         self._init(gpb)
 
         self._init_ndata_store()
@@ -1230,6 +1247,9 @@ def find_edges(self, edges, etype=None):
         tensor
             The destination node ID array.
         """
+        assert (
+            self._use_graphbolt is False
+        ), "find_edges is not supported in GraphBolt."
         if etype is None:
             assert (
                 len(self.etypes) == 1
diff --git a/tests/distributed/test_dist_graph_store.py b/tests/distributed/test_dist_graph_store.py
index b473ef163215..63b70c3cd2be 100644
--- a/tests/distributed/test_dist_graph_store.py
+++ b/tests/distributed/test_dist_graph_store.py
@@ -13,11 +13,13 @@
 import backend as F
 
 import dgl
+import dgl.graphbolt as gb
 import numpy as np
 import pytest
 import torch as th
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import (
+    dgl_partition_to_graphbolt,
     DistEmbedding,
     DistGraph,
     DistGraphServer,
@@ -38,12 +40,33 @@
     import struct
 
 
+def _verify_dist_graph_server_dgl(g):
+    # verify dtype of underlying graph
+    cg = g.client_g
+    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
+        if k in cg.ndata:
+            assert (
+                F.dtype(cg.ndata[k]) == dtype
+            ), "Data type of {} in ndata should be {}.".format(k, dtype)
+        if k in cg.edata:
+            assert (
+                F.dtype(cg.edata[k]) == dtype
+            ), "Data type of {} in edata should be {}.".format(k, dtype)
+
+
+def _verify_dist_graph_server_graphbolt(g):
+    graph = g.client_g
+    assert isinstance(graph, gb.FusedCSCSamplingGraph)
+    # [Rui][TODO] verify dtype of underlying graph.
+
+
 def run_server(
     graph_name,
     server_id,
     server_count,
     num_clients,
     shared_mem,
+    use_graphbolt=False,
 ):
     g = DistGraphServer(
@@ -53,19 +76,15 @@ def run_server(
         "/tmp/dist_graph/{}.json".format(graph_name),
         disable_shared_mem=not shared_mem,
         graph_format=["csc", "coo"],
+        use_graphbolt=use_graphbolt,
     )
-    print("start server", server_id)
-    # verify dtype of underlying graph
-    cg = g.client_g
-    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
-        if k in cg.ndata:
-            assert (
-                F.dtype(cg.ndata[k]) == dtype
-            ), "Data type of {} in ndata should be {}.".format(k, dtype)
-        if k in cg.edata:
-            assert (
-                F.dtype(cg.edata[k]) == dtype
-            ), "Data type of {} in edata should be {}.".format(k, dtype)
+    print(f"Starting server[{server_id}] with use_graphbolt={use_graphbolt}")
+    _verify = (
+        _verify_dist_graph_server_graphbolt
+        if use_graphbolt
+        else _verify_dist_graph_server_dgl
+    )
+    _verify(g)
     g.start()
 
 
@@ -110,18 +129,26 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
 
 
 def run_client_empty(
-    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
+    graph_name,
+    part_id,
+    server_count,
+    num_clients,
+    num_nodes,
+    num_edges,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
     check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
 
 
-def check_server_client_empty(shared_mem, num_servers, num_clients):
+def check_server_client_empty(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_graph(10000)
 
@@ -129,6 +156,9 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
     num_parts = 1
     graph_name = "dist_graph_test_1"
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)
 
     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
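(For orientation between the hunks above and below: the flow these tests exercise is to partition with `partition_graph`, convert the partition in place with `dgl_partition_to_graphbolt`, then attach a trainer with `use_graphbolt=True`. A minimal sketch of that client-side flow, assuming the same `/tmp/dist_graph` paths and `kv_ip_config.txt` used by the test fixtures and that the server processes are already running, is:

    import dgl
    from dgl.distributed import (
        DistGraph,
        dgl_partition_to_graphbolt,
        load_partition_book,
    )

    # Convert an existing DGL partition (described by its JSON config)
    # to the GraphBolt format.
    part_config = "/tmp/dist_graph/dist_graph_test_1.json"
    dgl_partition_to_graphbolt(part_config)

    # Trainer side: connect to the running servers and map the
    # GraphBolt-backed structure from shared memory.
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(part_config, 0)
    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=True)

This mirrors what `check_server_client_empty` and the other check helpers below do per process.)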
@@ -137,7 +167,14 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
         )
         serv_ps.append(p)
         p.start()
@@ -154,6 +191,7 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
                 num_clients,
                 g.num_nodes(),
                 g.num_edges(),
+                use_graphbolt,
             ),
         )
         p.start()
@@ -178,6 +216,7 @@ def run_client(
     num_nodes,
     num_edges,
     group_id,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     os.environ["DGL_GROUP_ID"] = str(group_id)
@@ -185,8 +224,10 @@ def run_client(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
-    check_dist_graph(g, num_clients, num_nodes, num_edges)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    check_dist_graph(
+        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
+    )
 
 
 def run_emb_client(
@@ -270,14 +311,20 @@ def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
 
 
 def run_client_hierarchy(
-    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
+    graph_name,
+    part_id,
+    server_count,
+    node_mask,
+    edge_mask,
+    return_dict,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
     node_mask = F.tensor(node_mask)
     edge_mask = F.tensor(edge_mask)
     nodes = node_split(
@@ -355,7 +402,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
         sys.exit(-1)
 
 
-def check_dist_graph(g, num_clients, num_nodes, num_edges):
+def check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):
     # Test API
     assert g.num_nodes() == num_nodes
     assert g.num_edges() == num_edges
@@ -373,9 +420,15 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
     assert np.all(F.asnumpy(feats == eids))
 
     # Test edge_subgraph
-    sg = g.edge_subgraph(eids)
-    assert sg.num_edges() == len(eids)
-    assert F.array_equal(sg.edata[dgl.EID], eids)
+    if use_graphbolt:
+        with pytest.raises(
+            AssertionError, match="find_edges is not supported in GraphBolt."
+        ):
+            g.edge_subgraph(eids)
+    else:
+        sg = g.edge_subgraph(eids)
+        assert sg.num_edges() == len(eids)
+        assert F.array_equal(sg.edata[dgl.EID], eids)
 
     # Test init node data
     new_shape = (g.num_nodes(), 2)
@@ -522,7 +575,9 @@ def check_dist_emb_server_client(
     print("clients have terminated")
 
 
-def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
+def check_server_client(
+    shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_graph(10000)
 
@@ -532,6 +587,9 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
     g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)
 
     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -546,6 +604,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                 num_servers,
                 num_clients,
                 shared_mem,
+                use_graphbolt,
             ),
         )
         serv_ps.append(p)
@@ -566,6 +625,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                 g.num_nodes(),
                 g.num_edges(),
                 group_id,
+                use_graphbolt,
             ),
         )
         p.start()
@@ -582,7 +642,12 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     print("clients have terminated")
 
 
-def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
+def check_server_client_hierarchy(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
+    if num_clients == 1:
+        # skip this test if there is only one client.
+        return
     prepare_dist(num_servers)
     g = create_random_graph(10000)
 
@@ -598,6 +663,9 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
         "/tmp/dist_graph",
         num_trainers_per_machine=num_clients,
     )
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)
 
     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -606,7 +674,14 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
         )
         serv_ps.append(p)
         p.start()
@@ -633,6 +708,7 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
                 node_mask,
                 edge_mask,
                 return_dict,
+                use_graphbolt,
             ),
         )
         p.start()
@@ -658,15 +734,23 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
 
 
 def run_client_hetero(
-    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
+    graph_name,
+    part_id,
+    server_count,
+    num_clients,
+    num_nodes,
+    num_edges,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
-    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    check_dist_graph_hetero(
+        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
+    )
 
 
 def create_random_hetero():
@@ -701,7 +785,9 @@ def create_random_hetero():
     return g
 
 
-def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
+def check_dist_graph_hetero(
+    g, num_clients, num_nodes, num_edges, use_graphbolt=False
+):
     # Test API
     for ntype in num_nodes:
         assert ntype in g.ntypes
@@ -754,12 +840,18 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
     assert expect_except
 
     # Test edge_subgraph
-    sg = g.edge_subgraph({"r1": eids})
-    assert sg.num_edges() == len(eids)
-    assert F.array_equal(sg.edata[dgl.EID], eids)
-    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
-    assert sg.num_edges() == len(eids)
-    assert F.array_equal(sg.edata[dgl.EID], eids)
+    if use_graphbolt:
+        with pytest.raises(
+            AssertionError, match="find_edges is not supported in GraphBolt."
+        ):
+            g.edge_subgraph({"r1": eids})
+    else:
+        sg = g.edge_subgraph({"r1": eids})
+        assert sg.num_edges() == len(eids)
+        assert F.array_equal(sg.edata[dgl.EID], eids)
+        sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
+        assert sg.num_edges() == len(eids)
+        assert F.array_equal(sg.edata[dgl.EID], eids)
 
     # Test init node data
     new_shape = (g.num_nodes("n1"), 2)
@@ -827,7 +919,9 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
     print("end")
 
 
-def check_server_client_hetero(shared_mem, num_servers, num_clients):
+def check_server_client_hetero(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_hetero()
 
@@ -835,6 +929,9 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
     num_parts = 1
     graph_name = "dist_graph_test_3"
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)
 
     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -843,7 +940,14 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
         )
         serv_ps.append(p)
         p.start()
@@ -862,6 +966,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
                 num_clients,
                 num_nodes,
                 num_edges,
+                use_graphbolt,
             ),
         )
         p.start()
@@ -886,21 +991,23 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
 @unittest.skipIf(
     dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
 )
-def test_server_client():
+@pytest.mark.parametrize("shared_mem", [True])
+@pytest.mark.parametrize("num_servers", [1])
+@pytest.mark.parametrize("num_clients", [1, 4])
+@pytest.mark.parametrize("use_graphbolt", [True, False])
+def test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):
     reset_envs()
     os.environ["DGL_DIST_MODE"] = "distributed"
-    check_server_client_hierarchy(False, 1, 4)
-    check_server_client_empty(True, 1, 1)
-    check_server_client_hetero(True, 1, 1)
-    check_server_client_hetero(False, 1, 1)
-    check_server_client(True, 1, 1)
-    check_server_client(False, 1, 1)
-    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
-    # root cause is unknown. Let's disable them for now.
-    # check_server_client(True, 2, 2)
-    # check_server_client(True, 1, 1, 2)
-    # check_server_client(False, 1, 1, 2)
-    # check_server_client(True, 2, 2, 2)
+    # [Rui]
+    # 1. `disable_shared_mem=False` is not supported yet. Skip it.
+    # 2. `num_servers` > 1 does not work on single machine. Skip it.
+    for func in [
+        check_server_client,
+        check_server_client_hetero,
+        check_server_client_empty,
+        check_server_client_hierarchy,
+    ]:
+        func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)
 
 
 @unittest.skip(reason="Skip due to glitch in CI")
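(One behavioral difference the parametrized tests above pin down: with `use_graphbolt=True`, `DistGraph.find_edges` asserts, so anything built on it, including `edge_subgraph`, fails with "find_edges is not supported in GraphBolt." A minimal sketch of a caller-side guard; `safe_edge_subgraph` is a hypothetical helper, and `g`/`eids` stand for a GraphBolt-backed `DistGraph` and an edge-ID tensor as in `check_dist_graph`:

    def safe_edge_subgraph(g, eids):
        """Return g.edge_subgraph(eids), or None when the GraphBolt-backed
        DistGraph rejects it ("find_edges is not supported in GraphBolt.")."""
        try:
            return g.edge_subgraph(eids)
        except AssertionError:
            return None

Callers that rely on edge subgraphs should either keep `use_graphbolt=False` or branch on the flag as the tests do.)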