From ee8b7b39ce19d6d6e0e97c48d2973a4ec586dd8d Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:22:38 +0800 Subject: [PATCH] [DistGB] enable GB sampling on heterograph (#7087) --- python/dgl/distributed/graph_services.py | 36 +++- .../distributed/test_distributed_sampling.py | 200 +++++++++++++++--- 2 files changed, 196 insertions(+), 40 deletions(-) diff --git a/python/dgl/distributed/graph_services.py b/python/dgl/distributed/graph_services.py index 58eeb6de1f89..9188a38675a9 100644 --- a/python/dgl/distributed/graph_services.py +++ b/python/dgl/distributed/graph_services.py @@ -1,4 +1,5 @@ """A set of graph services of getting subgraphs from DistGraph""" +import os from collections import namedtuple import numpy as np @@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): idtype=g.idtype, ) - etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID]) - src, dst = frontier.edges() + # For DGL partitions, the global edge IDs are always stored in the edata. + # For GraphBolt partitions, the edge type IDs are always stored in the + # edata. As for the edge IDs, they are stored in the edata if the graph is + # partitioned with `store_eids=True`. Otherwise, the edge IDs are not + # stored. + etype_ids, type_wise_eids = ( + gpb.map_to_per_etype(frontier.edata[EID]) + if EID in frontier.edata + else (frontier.edata[ETYPE], None) + ) etype_ids, idx = F.sort_1d(etype_ids) + if type_wise_eids is not None: + type_wise_eids = F.gather_row(type_wise_eids, idx) + + # Sort the edges by their edge types. 
+ src, dst = frontier.edges() src, dst = F.gather_row(src, idx), F.gather_row(dst, idx) - eid = F.gather_row(frontier.edata[EID], idx) - _, src = gpb.map_to_per_ntype(src) - _, dst = gpb.map_to_per_ntype(dst) + src_ntype_ids, src = gpb.map_to_per_ntype(src) + dst_ntype_ids, dst = gpb.map_to_per_ntype(dst) data_dict = dict() edge_ids = {} for etid, etype in enumerate(g.canonical_etypes): + src_ntype, _, dst_ntype = etype + src_ntype_id = g.get_ntype_id(src_ntype) + dst_ntype_id = g.get_ntype_id(dst_ntype) type_idx = etype_ids == etid if F.sum(type_idx, 0) > 0: data_dict[etype] = ( F.boolean_mask(src, type_idx), F.boolean_mask(dst, type_idx), ) - edge_ids[etype] = F.boolean_mask(eid, type_idx) + if "DGL_DIST_DEBUG" in os.environ: + assert torch.all( + src_ntype_id == src_ntype_ids[type_idx] + ), "source ntype is is not expected." + assert torch.all( + dst_ntype_id == dst_ntype_ids[type_idx] + ), "destination ntype is is not expected." + if type_wise_eids is not None: + edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx) hg = heterograph( data_dict, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}, diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 0795d4a03d25..eec8f51dbaa4 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -91,6 +91,9 @@ def start_sample_client_shuffle( dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt ) + assert ( + dgl.ETYPE not in sampled_graph.edata + ), "Etype should not be in homogeneous sampled graph." 
src, dst = sampled_graph.edges() src = orig_nid[src] dst = orig_nid[dst] @@ -460,23 +463,37 @@ def check_rpc_sampling_shuffle( assert p.exitcode == 0 -def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes): +def start_hetero_sample_client( + rank, + tmpdir, + disable_shared_mem, + nodes, + use_graphbolt=False, + return_eids=False, +): gpb = None if disable_shared_mem: _, _, _, gpb, _, _, _ = load_partition( tmpdir / "test_sampling.json", rank ) dgl.distributed.initialize("rpc_ip_config.txt") - dist_graph = DistGraph("test_sampling", gpb=gpb) + dist_graph = DistGraph( + "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt + ) assert "feat" in dist_graph.nodes["n1"].data assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n3"].data if gpb is None: gpb = dist_graph.get_partition_book() try: - sampled_graph = sample_neighbors(dist_graph, nodes, 3) + # Enable sanity check in distributed sampling. + os.environ["DGL_DIST_DEBUG"] = "1" + sampled_graph = sample_neighbors( + dist_graph, nodes, 3, use_graphbolt=use_graphbolt + ) block = dgl.to_block(sampled_graph, nodes) - block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] + if not use_graphbolt or return_eids: + block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] except Exception as e: print(traceback.format_exc()) block = None @@ -528,7 +545,9 @@ def start_hetero_etype_sample_client( return block, gpb -def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): +def check_rpc_hetero_sampling_shuffle( + tmpdir, num_server, use_graphbolt=False, return_eids=False +): generate_ip_config("rpc_ip_config.txt", num_server, num_server) g = create_random_hetero() @@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): num_hops=num_hops, part_method="metis", return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=return_eids, ) pserver_list = [] @@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): for i in 
range(num_server): p = ctx.Process( target=start_server, - args=(i, tmpdir, num_server > 1, "test_sampling"), + args=( + i, + tmpdir, + num_server > 1, + "test_sampling", + ["csc", "coo"], + use_graphbolt, + ), ) p.start() time.sleep(1) pserver_list.append(p) block, gpb = start_hetero_sample_client( - 0, tmpdir, num_server > 1, nodes={"n3": [0, 10, 99, 66, 124, 208]} + 0, + tmpdir, + num_server > 1, + nodes={"n3": [0, 10, 99, 66, 124, 208]}, + use_graphbolt=use_graphbolt, + return_eids=return_eids, ) - print("Done sampling") for p in pserver_list: p.join() assert p.exitcode == 0 @@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): # These are global Ids after shuffling. shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) - shuffled_eid = block.edges[etype].data[dgl.EID] - orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) + + assert np.all( + F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype)) + ) + + if use_graphbolt and not return_eids: + continue + + shuffled_eid = block.edges[etype].data[dgl.EID] orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) # Check the node Ids and edge Ids. 
@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype): return deg -def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): +def check_rpc_hetero_sampling_empty_shuffle( + tmpdir, num_server, use_graphbolt=False, return_eids=False +): generate_ip_config("rpc_ip_config.txt", num_server, num_server) g = create_random_hetero(empty=True) @@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): num_hops=num_hops, part_method="metis", return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=return_eids, ) pserver_list = [] @@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): for i in range(num_server): p = ctx.Process( target=start_server, - args=(i, tmpdir, num_server > 1, "test_sampling"), + args=( + i, + tmpdir, + num_server > 1, + "test_sampling", + ["csc", "coo"], + use_graphbolt, + ), ) p.start() time.sleep(1) @@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server): deg = get_degrees(g, orig_nids["n3"], "n3") empty_nids = F.nonzero_1d(deg == 0) block, gpb = start_hetero_sample_client( - 0, tmpdir, num_server > 1, nodes={"n3": empty_nids} + 0, + tmpdir, + num_server > 1, + nodes={"n3": empty_nids}, + use_graphbolt=use_graphbolt, + return_eids=return_eids, ) - print("Done sampling") for p in pserver_list: p.join() assert p.exitcode == 0 @@ -759,22 +813,36 @@ def create_random_bipartite(): return g -def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes): +def start_bipartite_sample_client( + rank, + tmpdir, + disable_shared_mem, + nodes, + use_graphbolt=False, + return_eids=False, +): gpb = None if disable_shared_mem: _, _, _, gpb, _, _, _ = load_partition( tmpdir / "test_sampling.json", rank ) dgl.distributed.initialize("rpc_ip_config.txt") - dist_graph = DistGraph("test_sampling", gpb=gpb) + dist_graph = DistGraph( + "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt + ) assert "feat" in dist_graph.nodes["user"].data assert "feat" in 
dist_graph.nodes["game"].data if gpb is None: gpb = dist_graph.get_partition_book() - sampled_graph = sample_neighbors(dist_graph, nodes, 3) + # Enable sanity check in distributed sampling. + os.environ["DGL_DIST_DEBUG"] = "1" + sampled_graph = sample_neighbors( + dist_graph, nodes, 3, use_graphbolt=use_graphbolt + ) block = dgl.to_block(sampled_graph, nodes) if sampled_graph.num_edges() > 0: - block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] + if not use_graphbolt or return_eids: + block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] dgl.distributed.exit_client() return block, gpb @@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client( return block, gpb -def check_rpc_bipartite_sampling_empty(tmpdir, num_server): +def check_rpc_bipartite_sampling_empty( + tmpdir, num_server, use_graphbolt=False, return_eids=False +): """sample on bipartite via sample_neighbors() which yields empty sample results""" generate_ip_config("rpc_ip_config.txt", num_server, num_server) @@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): num_hops=num_hops, part_method="metis", return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=return_eids, ) pserver_list = [] @@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): for i in range(num_server): p = ctx.Process( target=start_server, - args=(i, tmpdir, num_server > 1, "test_sampling"), + args=( + i, + tmpdir, + num_server > 1, + "test_sampling", + ["csc", "coo"], + use_graphbolt, + ), ) p.start() time.sleep(1) @@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server): deg = get_degrees(g, orig_nids["game"], "game") empty_nids = F.nonzero_1d(deg == 0) block, _ = start_bipartite_sample_client( - 0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]} + 0, + tmpdir, + num_server > 1, + nodes={"game": empty_nids, "user": [1]}, + use_graphbolt=use_graphbolt, + return_eids=return_eids, ) print("Done sampling") @@ -856,7 +940,9 @@ def 
check_rpc_bipartite_sampling_empty(tmpdir, num_server): assert len(block.etypes) == len(g.etypes) -def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): +def check_rpc_bipartite_sampling_shuffle( + tmpdir, num_server, use_graphbolt=False, return_eids=False +): """sample on bipartite via sample_neighbors() which yields non-empty sample results""" generate_ip_config("rpc_ip_config.txt", num_server, num_server) @@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): num_hops=num_hops, part_method="metis", return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=return_eids, ) pserver_list = [] @@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): for i in range(num_server): p = ctx.Process( target=start_server, - args=(i, tmpdir, num_server > 1, "test_sampling"), + args=( + i, + tmpdir, + num_server > 1, + "test_sampling", + ["csc", "coo"], + use_graphbolt, + ), ) p.start() time.sleep(1) @@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): deg = get_degrees(g, orig_nid_map["game"], "game") nids = F.nonzero_1d(deg > 0) block, gpb = start_bipartite_sample_client( - 0, tmpdir, num_server > 1, nodes={"game": nids, "user": [0]} + 0, + tmpdir, + num_server > 1, + nodes={"game": nids, "user": [0]}, + use_graphbolt=use_graphbolt, + return_eids=return_eids, ) print("Done sampling") for p in pserver_list: @@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server): # These are global Ids after shuffling. 
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) - shuffled_eid = block.edges[etype].data[dgl.EID] - orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) + assert np.all( + F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype)) + ) + + if use_graphbolt and not return_eids: + continue + + shuffled_eid = block.edges[etype].data[dgl.EID] orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) # Check the node Ids and edge Ids. @@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids): @pytest.mark.parametrize("num_server", [1]) -def test_rpc_hetero_sampling_shuffle(num_server): +@pytest.mark.parametrize("use_graphbolt,", [False, True]) +@pytest.mark.parametrize("return_eids", [False, True]) +def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids): reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) + check_rpc_hetero_sampling_shuffle( + Path(tmpdirname), + num_server, + use_graphbolt=use_graphbolt, + return_eids=return_eids, + ) @pytest.mark.parametrize("num_server", [1]) -def test_rpc_hetero_sampling_empty_shuffle(num_server): +@pytest.mark.parametrize("use_graphbolt", [False, True]) +@pytest.mark.parametrize("return_eids", [False, True]) +def test_rpc_hetero_sampling_empty_shuffle( + num_server, use_graphbolt, return_eids +): reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) + check_rpc_hetero_sampling_empty_shuffle( + Path(tmpdirname), + num_server, + use_graphbolt=use_graphbolt, + return_eids=return_eids, + ) @pytest.mark.parametrize("num_server", [1]) 
@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server): @pytest.mark.parametrize("num_server", [1]) -def test_rpc_bipartite_sampling_empty_shuffle(num_server): +@pytest.mark.parametrize("use_graphbolt", [False, True]) +@pytest.mark.parametrize("return_eids", [False, True]) +def test_rpc_bipartite_sampling_empty_shuffle( + num_server, use_graphbolt, return_eids +): reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server) + check_rpc_bipartite_sampling_empty( + Path(tmpdirname), num_server, use_graphbolt, return_eids + ) @pytest.mark.parametrize("num_server", [1]) -def test_rpc_bipartite_sampling_shuffle(num_server): +@pytest.mark.parametrize("use_graphbolt", [False, True]) +@pytest.mark.parametrize("return_eids", [False, True]) +def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids): reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server) + check_rpc_bipartite_sampling_shuffle( + Path(tmpdirname), num_server, use_graphbolt, return_eids + ) @pytest.mark.parametrize("num_server", [1])