Skip to content

Commit

Permalink
[DistGB] add verify logic for GraphBolt partitions (#7031)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jan 30, 2024
1 parent cda8b38 commit af87038
Showing 1 changed file with 66 additions and 3 deletions.
69 changes: 66 additions & 3 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import backend as F, graphbolt as gb
from ..base import dgl_warning, DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous
from ..convert import heterograph, to_homogeneous
from ..data.utils import load_graphs, load_tensors, save_graphs, save_tensors
from ..partition import (
get_peak_mem,
Expand Down Expand Up @@ -190,8 +190,71 @@ def _verify_dgl_partition(graph, part_id, gpb, ntypes, etypes):

def _verify_graphbolt_partition(graph, part_id, gpb, ntypes, etypes):
    """Verify the partition of a GraphBolt graph.

    The check rebuilds a heterograph from the partition's CSC structure and
    confirms that the node/edge types and per-type IDs are consistent with
    the partition book.

    Parameters
    ----------
    graph
        The partition graph. It must expose ``node_attributes``,
        ``edge_attributes``, ``total_num_nodes``, ``total_num_edges``,
        ``indices``, ``csc_indptr`` and ``type_per_edge``.
    part_id
        ID of the partition; only used in the success message.
    gpb
        The partition book used to map global IDs to per-type IDs
        (``map_to_per_ntype`` / ``map_to_per_etype``) and to query the number
        of nodes per node type.
    ntypes : dict
        Mapping from node type name to node type ID.
    etypes : dict
        Mapping from canonical edge type ``(src_ntype, etype, dst_ntype)`` to
        edge type ID.

    Raises
    ------
    AssertionError
        If a required attribute is missing, if the endpoint node types of an
        edge type do not match its canonical edge type, or if edge counts are
        inconsistent.
    """
    required_ndata_fields = [NID]
    required_edata_fields = [EID]
    assert all(
        field in graph.node_attributes for field in required_ndata_fields
    ), "the partition graph should contain node mapping to global node ID."
    assert all(
        field in graph.edge_attributes for field in required_edata_fields
    ), "the partition graph should contain edge mapping to global edge ID."

    num_nodes = graph.total_num_nodes
    num_edges = graph.total_num_edges

    # Expand the CSC layout into explicit (src, dst) pairs: ``indices`` holds
    # the source of every edge, and destination ``i`` is repeated once per
    # entry in its column (``csc_indptr[i + 1] - csc_indptr[i]``).
    local_src_ids = graph.indices
    local_dst_ids = torch.repeat_interleave(
        # Build the range on the same device as ``csc_indptr`` so that
        # ``repeat_interleave`` does not fail on a device mismatch.
        torch.arange(num_nodes, device=graph.csc_indptr.device),
        torch.diff(graph.csc_indptr),
    )
    global_src_ids = graph.node_attributes[NID][local_src_ids]
    global_dst_ids = graph.node_attributes[NID][local_dst_ids]

    etype_ids, type_wise_eids = gpb.map_to_per_etype(graph.edge_attributes[EID])
    if graph.type_per_edge is not None:
        assert torch.equal(etype_ids, graph.type_per_edge)

    # Group edges by edge type and apply the same permutation to every
    # per-edge tensor so they stay aligned.
    etype_ids, etype_ids_indices = torch.sort(etype_ids)
    global_src_ids = global_src_ids[etype_ids_indices]
    global_dst_ids = global_dst_ids[etype_ids_indices]
    type_wise_eids = type_wise_eids[etype_ids_indices]

    src_ntype_ids, src_type_wise_nids = gpb.map_to_per_ntype(global_src_ids)
    dst_ntype_ids, dst_type_wise_nids = gpb.map_to_per_ntype(global_dst_ids)

    data_dict = {}
    edge_ids = {}
    for c_etype, etype_id in etypes.items():
        idx = etype_ids == etype_id
        # Edge types with no edge in this partition are skipped.
        if not idx.any():
            continue
        src_ntype, _, dst_ntype = c_etype
        actual_src_ntype_ids = src_ntype_ids[idx]
        actual_dst_ntype_ids = dst_ntype_ids[idx]
        expected_src_ntype_ids = ntypes[src_ntype]
        expected_dst_ntype_ids = ntypes[dst_ntype]
        # Use the tensor reduction rather than the builtin ``all``, which
        # would iterate over the tensor element by element in Python.
        assert torch.all(actual_src_ntype_ids == expected_src_ntype_ids), (
            f"Unexpected types of source nodes for {c_etype}. Expected: "
            f"{expected_src_ntype_ids}, but got: {actual_src_ntype_ids}."
        )
        assert torch.all(actual_dst_ntype_ids == expected_dst_ntype_ids), (
            f"Unexpected types of destination nodes for {c_etype}. Expected: "
            f"{expected_dst_ntype_ids}, but got: {actual_dst_ntype_ids}."
        )
        data_dict[c_etype] = (src_type_wise_nids[idx], dst_type_wise_nids[idx])
        edge_ids[c_etype] = type_wise_eids[idx]

    # Make sure node/edge IDs are not out of range.
    hg = heterograph(
        data_dict, {ntype: gpb._num_nodes(ntype) for ntype in ntypes}
    )
    for c_etype, eids in edge_ids.items():
        hg.edges[c_etype].data[EID] = eids
    assert all(
        hg.num_edges(c_etype) == len(eids) for c_etype, eids in edge_ids.items()
    ), "The number of edges per etype in the partition graph is not correct."
    assert num_edges == hg.num_edges(), (
        "The total number of edges in the partition graph is not correct. "
        f"Expected: {num_edges}, but got: {hg.num_edges()}."
    )
    print(f"Partition {part_id} looks good!")


def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
Expand Down

0 comments on commit af87038

Please sign in to comment.