[GraphBolt][Doc] Updated MultiGPU tutorial #7126

Merged: 2 commits, Feb 21, 2024

Changes from all commits
133 changes: 62 additions & 71 deletions tutorials/multi/2_node_classification.py
@@ -68,12 +68,7 @@ def __init__(self, in_size, hidden_size, out_size):
         self.hidden_size = hidden_size
         self.out_size = out_size
         # Set the dtype for the layers manually.
-        self.set_layer_dtype(torch.float32)
-
-    def set_layer_dtype(self, dtype):
-        for layer in self.layers:
-            for param in layer.parameters():
-                param.data = param.data.to(dtype)
+        self.float()
 
     def forward(self, blocks, x):
         hidden_x = x
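
Note: the removed set_layer_dtype helper duplicated what nn.Module.float() already does, namely recursively casting every parameter and buffer to torch.float32. A minimal standalone sketch (not part of this PR) of the one-call equivalent:

import torch
from torch import nn

# A stand-in two-layer module; the tutorial's SAGE model behaves the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)).double()
model.float()  # recursively casts all parameters and buffers to torch.float32
assert all(p.dtype == torch.float32 for p in model.parameters())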
@@ -105,22 +100,38 @@ def create_dataloader(
     features,
     itemset,
     device,
-    drop_last=False,
-    shuffle=True,
-    drop_uneven_inputs=False,
+    is_train,
 ):
     datapipe = gb.DistributedItemSampler(
         item_set=itemset,
         batch_size=1024,
-        drop_last=drop_last,
-        shuffle=shuffle,
-        drop_uneven_inputs=drop_uneven_inputs,
+        drop_last=is_train,
+        shuffle=is_train,
+        drop_uneven_inputs=is_train,
     )
+    datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
+    # Now that we have moved to device, sample_neighbor and fetch_feature steps
+    # will be executed on GPUs.
     datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
-    datapipe = datapipe.copy_to(device)
-    dataloader = gb.DataLoader(datapipe)
-    return dataloader
+    return gb.DataLoader(datapipe)
 
 
+def weighted_reduce(tensor, weight, dst=0):
+    ########################################################################
+    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
+    # obtain overall average values.
+    #
+    # `torch.distributed.reduce` is used to reduce tensors from all the
+    # sub-processes to a specified process, ReduceOp.SUM is used by default.
+    #
+    # Because the GPUs may have differing numbers of processed items, we
+    # perform a weighted mean to calculate the exact loss and accuracy.
+    ########################################################################
+    dist.reduce(tensor=tensor, dst=dst)
+    weight = torch.tensor(weight, device=tensor.device)
+    dist.reduce(tensor=weight, dst=dst)
+    return tensor / weight
+
+
 ######################################################################
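
Note: weighted_reduce sums both the metric and its weight across ranks, so the destination rank ends up with sum(metric_i * n_i) / sum(n_i). This stays exact when ranks process different numbers of items, unlike averaging the per-rank means. A single-process sketch (mine, not from this PR) of the arithmetic it performs:

import torch

# Two simulated "ranks" with different item counts and per-item mean losses.
losses = [torch.tensor(0.5), torch.tensor(0.8)]  # per-rank mean loss
counts = [1000.0, 600.0]                         # items processed per rank

naive = sum(losses) / 2  # mean of means, biased toward the smaller rank
exact = sum(l * n for l, n in zip(losses, counts)) / sum(counts)
print(naive.item(), exact.item())  # 0.65 vs 0.6125; only the latter is exact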
@@ -140,15 +151,11 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device):
         graph,
         features,
         itemset,
-        drop_last=False,
-        shuffle=False,
-        drop_uneven_inputs=False,
-        device=device,
+        device,
+        is_train=False,
     )
 
-    for step, data in (
-        tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
-    ):
+    for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
         blocks = data.blocks
         x = data.node_features["feat"]
         y.append(data.labels)
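
Note: with the is_train consolidation, both call sites shrink to a single flag. A usage sketch, assuming the tutorial's graph, features, item sets, and device are already in scope:

train_dataloader = create_dataloader(
    graph, features, train_set, device, is_train=True
)  # shuffled; ragged last batch and uneven per-rank inputs are dropped
valid_dataloader = create_dataloader(
    graph, features, valid_set, device, is_train=False
)  # deterministic order; every item is evaluated exactly once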
@@ -161,7 +168,7 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device):
             num_classes=num_classes,
         )
 
-    return res.to(device)
+    return res.to(device), sum(y_i.size(0) for y_i in y)
 
 
 ######################################################################
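
Note: evaluate now returns the local item count alongside the local accuracy so the caller can form the weighted mean. The metric itself comes from torchmetrics (imported as MF in this tutorial); a small sketch of that call in isolation:

import torch
import torchmetrics.functional as MF

logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.3]])
labels = torch.tensor([0, 1, 1])
acc = MF.accuracy(logits, labels, task="multiclass", num_classes=2)
print(acc)  # tensor(0.6667): two of the three argmax predictions are correct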
@@ -196,22 +203,17 @@ def train(
         features,
         train_set,
         device,
-        drop_last=False,
-        shuffle=True,
-        drop_uneven_inputs=False,
+        is_train=True,
     )
 
     for epoch in range(5):
         epoch_start = time.time()
 
         model.train()
-        total_loss = torch.tensor(0, dtype=torch.float).to(device)
+        total_loss = torch.tensor(0, dtype=torch.float, device=device)
+        num_train_items = 0
         with Join([model]):
-            for step, data in (
-                tqdm.tqdm(enumerate(dataloader))
-                if rank == 0
-                else enumerate(dataloader)
-            ):
+            for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
                 # The input features are from the source nodes in the first
                 # layer's computation graph.
                 x = data.node_features["feat"]
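
Note: the Join([model]) context seen in this hunk matters because ranks can end up with different batch counts, and DDP ranks that finish early would otherwise hang in the gradient allreduce. A runnable CPU sketch (adapted from the PyTorch Join tutorial, not from this PR) with deliberately uneven inputs:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(torch.nn.Linear(1, 1))
    # Rank 0 gets 5 batches, rank 1 gets 7: uneven on purpose.
    inputs = [torch.ones(1, 1) for _ in range(5 + rank * 2)]
    with Join([model]):  # finished ranks shadow the remaining allreduces
        for x in inputs:
            model(x).sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)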
@@ -231,40 +233,31 @@ def train(
                 loss.backward()
                 optimizer.step()
 
-                total_loss += loss
+                total_loss += loss * y.size(0)
+                num_train_items += y.size(0)
 
         # Evaluate the model.
         if rank == 0:
             print("Validating...")
-        acc = (
-            evaluate(
-                rank,
-                model,
-                graph,
-                features,
-                valid_set,
-                num_classes,
-                device,
-            )
-            / world_size
+        acc, num_val_items = evaluate(
+            rank,
+            model,
+            graph,
+            features,
+            valid_set,
+            num_classes,
+            device,
         )
-        ########################################################################
-        # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
-        # obtain overall average values.
-        #
-        # `torch.distributed.reduce` is used to reduce tensors from all the
-        # sub-processes to a specified process, ReduceOp.SUM is used by default.
-        ########################################################################
-        dist.reduce(tensor=acc, dst=0)
-        total_loss /= step + 1
-        dist.reduce(tensor=total_loss, dst=0)
-        dist.barrier()
+        total_loss = weighted_reduce(total_loss, num_train_items)
+        acc = weighted_reduce(acc * num_val_items, num_val_items)
 
+        # We synchronize before measuring the epoch time.
+        torch.cuda.synchronize()
         epoch_end = time.time()
         if rank == 0:
             print(
                 f"Epoch {epoch:05d} | "
-                f"Average Loss {total_loss.item() / world_size:.4f} | "
+                f"Average Loss {total_loss.item():.4f} | "
                 f"Accuracy {acc.item():.4f} | "
                 f"Time {epoch_end - epoch_start:.4f}"
             )
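
Note on the added torch.cuda.synchronize(): CUDA kernels launch asynchronously, so reading the wall clock without a synchronize measures launch overhead rather than actual GPU work. A minimal illustration (mine, not from this PR):

import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    start = time.time()
    y = x @ x  # returns immediately; the matmul runs asynchronously
    torch.cuda.synchronize()  # block until the GPU has actually finished
    print(f"elapsed: {time.time() - start:.4f}s")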
@@ -292,8 +285,9 @@ def run(rank, world_size, devices, dataset):
         rank=rank,
     )
 
-    graph = dataset.graph
-    features = dataset.feature
+    # Pin the graph and features in-place to enable GPU access.
+    graph = dataset.graph.pin_memory_()
+    features = dataset.feature.pin_memory_()
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
     num_classes = dataset.tasks[0].metadata["num_classes"]
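
Note: pin_memory_() pins the graph and feature storage in place (page-locked host memory), which is what lets the GPU read them directly and run the sampling and feature-fetch stages on device. For plain tensors the analogous call is pin_memory(), which returns a pinned copy rather than pinning in place; a small sketch:

import torch

feat = torch.randn(1000, 128)  # ordinary pageable host memory
feat = feat.pin_memory()       # returns a page-locked copy
print(feat.is_pinned())        # True
if torch.cuda.is_available():
    # Pinned memory enables asynchronous host-to-device copies and direct
    # GPU access to host data, which GraphBolt exploits for GPU sampling.
    dev_feat = feat.to("cuda", non_blocking=True)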
@@ -325,20 +319,17 @@ def run(rank, world_size, devices, dataset):
     if rank == 0:
         print("Testing...")
     test_set = dataset.tasks[0].test_set
-    test_acc = (
-        evaluate(
-            rank,
-            model,
-            graph,
-            features,
-            itemset=test_set,
-            num_classes=num_classes,
-            device=device,
-        )
-        / world_size
+    test_acc, num_test_items = evaluate(
+        rank,
+        model,
+        graph,
+        features,
+        itemset=test_set,
+        num_classes=num_classes,
+        device=device,
     )
-    dist.reduce(tensor=test_acc, dst=0)
-    dist.barrier()
+    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)
+
     if rank == 0:
         print(f"Test Accuracy {test_acc.item():.4f}")
 