[GraphBolt][PyG] Modify PyG example with to_pyg_data #7123

Merged
merged 15 commits on Mar 1, 2024
218 changes: 131 additions & 87 deletions examples/sampling/pyg/node_classification.py
@@ -1,16 +1,16 @@
"""
This script demonstrates node classification with GraphSAGE on large graphs,
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently manages
data loading for large datasets, crucial for mini-batch processing. Post data
loading, PyG's user-friendly framework takes over for training, showcasing seamless
integration with GraphBolt. This combination offers an efficient alternative to
traditional Deep Graph Library (DGL) methods, highlighting adaptability and
scalability in handling large-scale graph data for diverse real-world applications.


merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently
manages data loading for large datasets, crucial for mini-batch processing.
After data loading, PyG's user-friendly framework takes over for training,
showcasing seamless integration with GraphBolt. This combination offers an
efficient alternative to traditional Deep Graph Library (DGL) methods,
highlighting adaptability and scalability in handling large-scale graph data
for diverse real-world applications.

Key Features:
- Implements the GraphSAGE model, a scalable GNN, for node classification on large graphs.
- Implements the GraphSAGE model, a scalable GNN, for node classification on
large graphs.
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- The script is well-documented, providing clear explanations at each step.
@@ -38,6 +38,8 @@
│ │
│ ├───> Forward and backward passes
│ │
│ ├───> Convert GraphBolt MiniBatch to PyG Data
│ │
│ └───> Parameters optimization
└───> Evaluate the model
@@ -56,6 +58,7 @@
import torch.nn.functional as F
import torchmetrics.functional as MF
from torch_geometric.nn import SAGEConv
from tqdm import tqdm
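
# The "Convert GraphBolt MiniBatch to PyG Data" step in the flow above boils
# down to the pattern below, a simplified excerpt of the training loop defined
# later in this file (`dataloader` and `model` stand in for the objects built
# in `main()`):
#
#     for minibatch in dataloader:              # GraphBolt MiniBatch
#         pyg_data = minibatch.to_pyg_data()    # -> torch_geometric.data.Data
#         out = model(pyg_data.x, pyg_data.edge_index)
#         loss = F.cross_entropy(out[: pyg_data.y.shape[0]], pyg_data.y)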


class GraphSAGE(torch.nn.Module):
@@ -67,6 +70,8 @@ class GraphSAGE(torch.nn.Module):
# - 'in_size', 'hidden_size', 'out_size' are the sizes of
# the input, hidden, and output features, respectively.
# - The forward method defines the computation performed at every call.
# - It's adapted from the official PyG example, which can be found at
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
#####################################################################
def __init__(self, in_size, hidden_size, out_size):
super(GraphSAGE, self).__init__()
@@ -75,41 +80,46 @@ def __init__(self, in_size, hidden_size, out_size):
self.layers.append(SAGEConv(hidden_size, hidden_size))
self.layers.append(SAGEConv(hidden_size, out_size))

def forward(self, blocks, x, device):
h = x
for i, (layer, block) in enumerate(zip(self.layers, blocks)):
src, dst = block.edges()
edge_index = torch.stack([src, dst], dim=0)
h_src, h_dst = h, h[: block.number_of_dst_nodes()]
h = layer((h_src, h_dst), edge_index)
if i != len(blocks) - 1:
h = F.relu(h)
return h
def forward(self, x, edge_index):
for i, layer in enumerate(self.layers):
x = layer(x, edge_index)
if i != len(self.layers) - 1:
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)
return x

def inference(self, args, dataloader, x_all, device):
"""Conduct layer-wise inference to get all the node embeddings."""
for i, layer in tqdm(enumerate(self.layers), "inference"):
xs = []
for minibatch in dataloader:
# Call `to_pyg_data` to convert the GB MiniBatch to PyG Data.
pyg_data = minibatch.to_pyg_data()
x = x_all[minibatch.node_ids()].to(device)
edge_index = pyg_data.edge_index
x = layer(x, edge_index)
x = x[: 4 * args.batch_size]
if i != len(self.layers) - 1:
x = x.relu()
xs.append(x.cpu())
x_all = torch.cat(xs, dim=0)
return x_all
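
# A quick usage sketch for the model above, with hypothetical sizes (for
# reference, ogbn-products has 100-dimensional features and 47 classes); it
# only illustrates the shapes the `forward` method expects and returns:
#
#     model = GraphSAGE(in_size=100, hidden_size=256, out_size=47)
#     x = torch.randn(8, 100)                                  # 8 nodes
#     edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # COO edges
#     logits = model(x, edge_index)                            # shape: (8, 47)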

def create_dataloader(dataset_set, graph, feature, device, is_train):
#####################################################################
# (HIGHLIGHT) Create a data loader for efficiently loading graph data.
#
# - 'ItemSampler' samples mini-batches of node IDs from the dataset.
# - 'sample_neighbor' performs neighbor sampling on the graph.
# - 'FeatureFetcher' fetches node features based on the sampled subgraph.
# - 'CopyTo' copies the fetched data to the specified device.

#####################################################################
# Create a datapipe for mini-batch sampling with a specific neighbor fanout.
# Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.
# We're using `sample_neighbor` for consistency with DGL's sampling API.
# Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,
# which could provide further optimization and efficiency for GNN training.
# Users are encouraged to explore these advanced features for potentially improved performance.

def create_dataloader(
dataset_set, graph, feature, batch_size, fanout, device, job
):
# Initialize an ItemSampler to sample mini-batches from the dataset.
datapipe = gb.ItemSampler(
dataset_set, batch_size=1024, shuffle=is_train, drop_last=is_train
dataset_set,
batch_size=batch_size,
shuffle=(job == "train"),
drop_last=(job == "train"),
)
# Sample neighbors for each node in the mini-batch.
datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
datapipe = datapipe.sample_neighbor(
graph, fanout if job != "infer" else [-1]
)
# Fetch node features for the sampled subgraph.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
# Copy the data to the specified device.
@@ -120,59 +130,49 @@ def create_dataloader(dataset_set, graph, feature, device, is_train):
return dataloader
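
# A minimal sketch (not part of this PR) of the alternative sampler mentioned
# in the comment above. It assumes `sample_layer_neighbor` accepts the same
# `(graph, fanouts)` arguments as `sample_neighbor`, and that the pipeline
# ends with the same `copy_to` / `gb.DataLoader` steps as in the collapsed
# lines of `create_dataloader` above.
def create_layer_neighbor_dataloader(
    dataset_set, graph, feature, batch_size, fanout, device, job
):
    datapipe = gb.ItemSampler(
        dataset_set,
        batch_size=batch_size,
        shuffle=(job == "train"),
        drop_last=(job == "train"),
    )
    # Layer-neighbor sampling in place of plain neighbor sampling.
    datapipe = datapipe.sample_layer_neighbor(graph, fanout)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    return gb.DataLoader(datapipe)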


def train(model, dataloader, optimizer, criterion, device, num_classes):
#####################################################################
# (HIGHLIGHT) Train the model for one epoch.
#
# - Iterates over the data loader, fetching mini-batches of graph data.
# - For each mini-batch, it performs a forward pass, computes loss, and
# updates the model parameters.
# - The function returns the average loss and accuracy for the epoch.
#
# Parameters:
# model: The GraphSAGE model.
# dataloader: DataLoader that provides mini-batches of graph data.
# optimizer: Optimizer used for updating model parameters.
# criterion: Loss function used for training.
# device: The device (CPU/GPU) to run the training on.
#####################################################################

def train(model, dataloader, optimizer):
model.train() # Set the model to training mode
total_loss = 0 # Accumulator for the total loss
total_correct = 0 # Accumulator for the total number of correct predictions
total_samples = 0 # Accumulator for the total number of samples processed
num_batches = 0 # Counter for the number of mini-batches processed

for minibatch in dataloader:
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
for _, minibatch in tqdm(enumerate(dataloader), "training"):
#####################################################################
# (HIGHLIGHT) Convert GraphBolt MiniBatch to a PyG Data object.
#
# Call `MiniBatch.to_pyg_data()`; it returns a PyG `Data` object with the
# necessary tensors and metadata.
#####################################################################
pyg_data = minibatch.to_pyg_data()

optimizer.zero_grad()
out = model(minibatch.blocks, node_features, device)
loss = criterion(out, labels)
total_loss += loss.item()
total_correct += MF.accuracy(
out, labels, task="multiclass", num_classes=num_classes
) * labels.size(0)
total_samples += labels.size(0)
out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
y = pyg_data.y
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()

total_loss += float(loss)
total_correct += int(out.argmax(dim=-1).eq(y).sum())
total_samples += y.shape[0]
num_batches += 1
avg_loss = total_loss / num_batches
avg_accuracy = total_correct / total_samples
return avg_loss, avg_accuracy


@torch.no_grad()
def evaluate(model, dataloader, device, num_classes):
def evaluate(model, dataloader, num_classes):
model.eval()
y_hats = []
ys = []
for minibatch in dataloader:
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
out = model(minibatch.blocks, node_features, device)
for _, minibatch in tqdm(enumerate(dataloader), "evaluating"):
pyg_data = minibatch.to_pyg_data()
out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
y = pyg_data.y
y_hats.append(out)
ys.append(labels)
ys.append(y)

return MF.accuracy(
torch.cat(y_hats),
@@ -182,52 +182,96 @@ def evaluate(model, dataloader, device, num_classes):
)


@torch.no_grad()
def layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
):
model.eval()
features = feature.read("node", None, "feat")
pred = model.inference(args, infer_dataloader, features, device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)

return MF.accuracy(
pred,
label,
task="multiclass",
num_classes=num_classes,
)


def main():
parser = argparse.ArgumentParser(
description="Which dataset are you going to use?"
)
parser.add_argument(
"--dataset",
type=str,
default="ogbn-arxiv",
default="ogbn-products",
help='Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")',
)
parser.add_argument(
"--epochs", type=int, default=10, help="Number of training epochs."
)
parser.add_argument(
"--batch-size", type=int, default=1024, help="Batch size for training."
)
args = parser.parse_args()
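# Example invocation (dataset availability assumed; flags as defined above):
#   python node_classification.py --dataset ogbn-arxiv --epochs 10 --batch-size 1024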

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = args.dataset
dataset = gb.BuiltinDataset(dataset_name).load()
graph = dataset.graph
feature = dataset.feature
feature = dataset.feature.to(device)
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
all_nodes_set = dataset.all_nodes_set
num_classes = dataset.tasks[0].metadata["num_classes"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataloader = create_dataloader(
train_set, graph, feature, device, is_train=True
train_set,
graph,
feature,
args.batch_size,
[10, 10, 10],
device,
job="train",
)
valid_dataloader = create_dataloader(
valid_set, graph, feature, device, is_train=False
valid_set,
graph,
feature,
args.batch_size,
[10, 10, 10],
device,
job="evaluate",
)
test_dataloader = create_dataloader(
test_set, graph, feature, device, is_train=False
infer_dataloader = create_dataloader(
all_nodes_set,
graph,
feature,
4 * args.batch_size,
[-1],
device,
job="infer",
)
in_channels = feature.size("node", None, "feat")[0]
hidden_channels = 128
hidden_channels = 256
model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(10):
train_loss, train_accuracy = train(
model, train_dataloader, optimizer, criterion, device, num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
for epoch in range(args.epochs):
train_loss, train_accuracy = train(model, train_dataloader, optimizer)

valid_accuracy = evaluate(model, valid_dataloader, device, num_classes)
valid_accuracy = evaluate(model, valid_dataloader, num_classes)
print(
f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
f"Epoch {epoch}, Train Loss: {train_loss:.4f}, "
f"Train Accuracy: {train_accuracy:.4f}, "
f"Valid Accuracy: {valid_accuracy:.4f}"
)
test_accuracy = evaluate(model, test_dataloader, device, num_classes)
test_accuracy = layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
)
print(f"Test Accuracy: {test_accuracy:.4f}")

