Skip to content

Commit

Permalink
Merge branch 'master' into ondisk_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov authored Feb 7, 2024
2 parents 10b3cb3 + 8cf5ad8 commit 47fdb4a
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 6 deletions.
7 changes: 5 additions & 2 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ def _sample_neighbors_graphbolt(

# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
local_dst = torch.repeat_interleave(
subgraph.original_column_node_ids, torch.diff(subgraph.indptr)
local_dst = gb.expand_indptr(
subgraph.indptr,
dtype=local_src.dtype,
node_ids=subgraph.original_column_node_ids,
output_size=local_src.shape[0],
)
global_nid_mapping = g.node_attributes[NID]
global_src = global_nid_mapping[local_src]
Expand Down
5 changes: 2 additions & 3 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,10 @@ def _verify_graphbolt_partition(graph, part_id, gpb, ntypes, etypes):
field in graph.edge_attributes for field in required_edata_fields
), "the partition graph should contain edge mapping to global edge ID."

num_nodes = graph.total_num_nodes
num_edges = graph.total_num_edges
local_src_ids = graph.indices
local_dst_ids = torch.repeat_interleave(
torch.arange(num_nodes), torch.diff(graph.csc_indptr)
local_dst_ids = gb.expand_indptr(
graph.csc_indptr, dtype=local_src_ids.dtype, output_size=num_edges
)
global_src_ids = graph.node_attributes[NID][local_src_ids]
global_dst_ids = graph.node_attributes[NID][local_dst_ids]
Expand Down
46 changes: 45 additions & 1 deletion python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import dgl
from dgl.utils import recursive_apply

from .base import etype_str_to_tuple
from .base import etype_str_to_tuple, expand_indptr
from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph

Expand Down Expand Up @@ -474,6 +474,50 @@ def node_pairs_with_labels(self):
else:
return None

def to_pyg_data(self):
    """Construct a PyG Data from `MiniBatch`.

    This function only supports node classification task on a homogeneous
    graph and the number of features cannot be more than one.

    Returns
    -------
    torch_geometric.data.Data
        A ``Data`` object built from three fields of this minibatch:

        * ``x`` -- the single tensor stored in ``self.node_features``, or
          ``None`` when ``self.node_features`` is ``None``.
        * ``edge_index`` -- a two-row tensor whose first row holds the
          column node of every sampled edge (obtained by expanding each
          subgraph's ``indptr``) and whose second row holds the matching
          entries of ``indices``. Edges of all subgraphs are merged and
          duplicate columns are removed with ``torch.unique``. It is
          ``None`` when ``self.sampled_subgraphs`` is ``None`` or does not
          contain any non-``None`` subgraph.
        * ``y`` -- ``self.labels``, passed through unchanged.

    Raises
    ------
    AssertionError
        If ``self.node_features`` is not ``None`` and does not contain
        exactly one feature.
    """
    # Imported lazily so the dependency is only needed when this
    # conversion is actually requested.
    from torch_geometric.data import Data

    edge_index = None
    if self.sampled_subgraphs is not None:
        col_nodes = []
        row_nodes = []
        for subgraph in self.sampled_subgraphs:
            if subgraph is None:
                continue
            sampled_csc = subgraph.sampled_csc
            indices = sampled_csc.indices
            # One column id per edge, in the same dtype as `indices` so
            # that both rows can be stacked into a single tensor.
            col_nodes.append(
                expand_indptr(
                    sampled_csc.indptr,
                    dtype=indices.dtype,
                    output_size=len(indices),
                )
            )
            row_nodes.append(indices)
        # `torch.cat` rejects an empty list, so a minibatch whose subgraph
        # list is empty (or only holds `None`) yields no edge index
        # instead of raising.
        if row_nodes:
            edge_index = torch.unique(
                torch.stack((torch.cat(col_nodes), torch.cat(row_nodes))),
                dim=1,
            )

    if self.node_features is None:
        node_features = None
    else:
        assert (
            len(self.node_features) == 1
        ), "`to_pyg_data` only supports single feature homogeneous graph."
        node_features = next(iter(self.node_features.values()))

    return Data(
        x=node_features,
        edge_index=edge_index,
        y=self.labels,
    )

def to(self, device: torch.device): # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection."""

Expand Down
83 changes: 83 additions & 0 deletions tests/python/pytorch/graphbolt/test_minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,86 @@ def test_dgl_link_predication_hetero(mode):
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype],
)


def test_to_pyg_data():
    """Check `MiniBatch.to_pyg_data` on a full minibatch and on edge cases.

    Covers: two sampled subgraphs merged into one de-duplicated edge index,
    `sampled_subgraphs=None`, `node_features=None`, `labels=None`, and the
    rejection of more than one node feature.
    """
    test_subgraph_a = gb.SampledSubgraphImpl(
        sampled_csc=gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3, 5, 6]),
            indices=torch.tensor([0, 1, 2, 2, 1, 2]),
        ),
        original_column_node_ids=torch.tensor([10, 11, 12, 13]),
        original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
        original_edge_ids=torch.tensor([10, 11, 12, 13]),
    )
    test_subgraph_b = gb.SampledSubgraphImpl(
        sampled_csc=gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3]),
            indices=torch.tensor([1, 2, 0]),
        ),
        original_row_node_ids=torch.tensor([10, 11, 12]),
        original_edge_ids=torch.tensor([10, 15, 17]),
        original_column_node_ids=torch.tensor([10, 11]),
    )
    # Union of both subgraphs' (column, row) pairs, de-duplicated and
    # sorted column-wise.
    expected_edge_index = torch.tensor(
        [[0, 0, 1, 1, 1, 2, 2, 3], [0, 1, 0, 1, 2, 1, 2, 2]]
    )
    expected_node_features = torch.tensor([[1], [2], [3], [4]])
    expected_labels = torch.tensor([0, 1])
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
        node_features={"feat": expected_node_features},
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    pyg_data.validate()
    assert torch.equal(pyg_data.edge_index, expected_edge_index)
    assert torch.equal(pyg_data.x, expected_node_features)
    assert torch.equal(pyg_data.y, expected_labels)

    # Test with sampled_subgraphs as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=None,
        node_features={"feat": expected_node_features},
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.edge_index is None, "Edge index should be none."

    # Test with node_features as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[test_subgraph_a],
        node_features=None,
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.x is None, "Node features should be None."

    # Test with labels as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[test_subgraph_a],
        node_features={"feat": expected_node_features},
        labels=None,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.y is None, "Labels should be None."

    # Test with multiple features.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[test_subgraph_a],
        node_features={
            "feat": expected_node_features,
            "extra_feat": torch.tensor([[3], [4]]),
        },
        labels=expected_labels,
    )
    try:
        test_minibatch.to_pyg_data()
    except AssertionError as e:
        assert (
            str(e)
            == "`to_pyg_data` only supports single feature homogeneous graph."
        )
    else:
        # Raised from the `else` clause so the `except` above cannot
        # swallow it: the test must fail when no error was raised.
        raise AssertionError("Multiple features case should raise an error.")

0 comments on commit 47fdb4a

Please sign in to comment.