Skip to content

Commit

Permalink
[GraphBolt] Modify SubgraphSampler to support seeds. (#7049)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
  • Loading branch information
yxy235 and Ubuntu authored Feb 6, 2024
1 parent ee8b7b3 commit 845864d
Show file tree
Hide file tree
Showing 2 changed files with 1,395 additions and 20 deletions.
121 changes: 119 additions & 2 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Dict

import torch
from torch.utils.data import functional_datapipe

from .base import etype_str_to_tuple
Expand Down Expand Up @@ -69,10 +70,16 @@ def _preprocess(minibatch):
seeds_timestamp = (
minibatch.timestamp if hasattr(minibatch, "timestamp") else None
)
elif minibatch.seeds is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_seeds,
) = SubgraphSampler._seeds_preprocess(minibatch)
else:
raise ValueError(
f"Invalid minibatch {minibatch}: Either `node_pairs` or "
"`seed_nodes` should have a value."
f"Invalid minibatch {minibatch}: One of `node_pairs`, "
"`seed_nodes` and `seeds` should have a value."
)
minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp
Expand Down Expand Up @@ -226,6 +233,116 @@ def sampling_stages(self, datapipe):
"""
return datapipe.transform(self._sample)

@staticmethod
def _seeds_preprocess(minibatch):
    """Build the sampling inputs from the `seeds` of a minibatch.

    The seeds are flattened into per-type node lists, handed to
    `compact_temporal_nodes` (when the minibatch has a `timestamp`
    attribute) or to `unique_and_compact` (otherwise), and the compacted
    ids are then reshaped back into the layout of the original `seeds`.

    Parameters
    ----------
    minibatch: MiniBatch
        The minibatch. `minibatch.seeds` is either a tensor (homogeneous
        graph) or a dictionary of tensors (heterogeneous graph). In the
        dictionary form each tensor must be of shape `(N,)`, keyed by node
        type, or `(N, 2)`, keyed by an edge type string of the form
        `"src_type:relation:dst_type"`.

    Returns
    -------
    unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
        The first value returned by the compaction helper; keyed by node
        type for heterogeneous input.
    nodes_timestamp: None or torch.Tensor or Dict[str, torch.Tensor]
        Timestamps returned by `compact_temporal_nodes`, or None when the
        minibatch has no `timestamp` attribute.
    compacted_seeds: torch.Tensor or Dict[str, torch.Tensor]
        Compacted ids arranged in the same shape (and under the same keys)
        as `minibatch.seeds`.
    """
    use_timestamp = hasattr(minibatch, "timestamp")
    seeds = minibatch.seeds
    is_heterogeneous = isinstance(seeds, Dict)

    # Step 1: collect nodes (and, if present, one timestamp per node).
    if is_heterogeneous:
        nodes = defaultdict(list)
        nodes_timestamp = defaultdict(list) if use_timestamp else None
        for etype, pair in seeds.items():
            # NOTE: kept as an assert so the raised exception type stays
            # the same for existing callers.
            assert pair.ndim == 1 or (
                pair.ndim == 2 and pair.shape[1] == 2
            ), (
                "Only tensor with shape 1*N and N*2 is "
                + f"supported now, but got {pair.shape}."
            )
            # "src:rel:dst" -> [src, dst]; a bare node type -> [ntype].
            ntypes = etype.split(":")[::2]
            # Normalize to 2-D so columns can be addressed uniformly.
            pair = pair.view(pair.shape[0], -1)
            if use_timestamp:
                pos_timestamp = minibatch.timestamp[etype]
                # Seeds beyond the timestamped ones are treated as
                # negatives that inherit the timestamp of their positive.
                negative_ratio = (
                    pair.shape[0] // pos_timestamp.shape[0] - 1
                )
                neg_timestamp = pos_timestamp.repeat_interleave(
                    negative_ratio
                )
            for i, ntype in enumerate(ntypes):
                nodes[ntype].append(pair[:, i])
                if use_timestamp:
                    # Two timestamp tensors per node tensor; this relies
                    # on the helper concatenating each list in order
                    # (helper not visible here -- TODO confirm).
                    nodes_timestamp[ntype].append(pos_timestamp)
                    nodes_timestamp[ntype].append(neg_timestamp)
    else:
        # Row-major flatten: for (N, 2) seeds this yields
        # [s0, d0, s1, d1, ...].
        nodes = [seeds.view(-1)]
        nodes_timestamp = None
        if use_timestamp:
            negative_ratio = (
                seeds.shape[0] // minibatch.timestamp.shape[0] - 1
            )
            neg_timestamp = minibatch.timestamp.repeat_interleave(
                negative_ratio
            )
            # One timestamp per seed row.
            seeds_timestamp = torch.cat(
                (minibatch.timestamp, neg_timestamp)
            )
            if seeds.ndim > 1:
                # All nodes of a row share that row's timestamp. Repeat
                # element-wise so the order matches the row-major flatten
                # above; stacking whole copies would pair node k with the
                # timestamp of a different row.
                seeds_timestamp = seeds_timestamp.repeat_interleave(
                    seeds.shape[1]
                )
            nodes_timestamp = [seeds_timestamp]

    # Step 2: compact the collected nodes.
    if use_timestamp:
        (
            unique_seeds,
            nodes_timestamp,
            compacted,
        ) = compact_temporal_nodes(nodes, nodes_timestamp)
    else:
        unique_seeds, compacted = unique_and_compact(nodes)

    # Step 3: map the compacted ids back in the same order as collected.
    if is_heterogeneous:
        compacted_seeds = {}
        for etype, pair in seeds.items():
            if pair.ndim == 1:
                compacted_seeds[etype] = compacted[etype].pop(0)
            else:
                src_type, _, dst_type = etype_str_to_tuple(etype)
                src = compacted[src_type].pop(0)
                dst = compacted[dst_type].pop(0)
                # Re-assemble the (N, 2) layout from the two columns.
                compacted_seeds[etype] = (
                    torch.cat((src, dst)).view(2, -1).T
                )
    else:
        compacted_seeds = compacted[0].view(seeds.shape)

    return (
        unique_seeds,
        nodes_timestamp,
        compacted_seeds,
    )

def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
Expand Down
Loading

0 comments on commit 845864d

Please sign in to comment.