
[GraphBolt][CUDA] refine the node_classification examples. #7136

Merged (5 commits) on Feb 23, 2024
81 changes: 48 additions & 33 deletions examples/multigpu/graphbolt/node_classification.py
@@ -145,6 +145,23 @@ def create_dataloader(
return dataloader


def weighted_reduce(tensor, weight, dst=0):
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
# obtain overall average values.
#
# `torch.distributed.reduce` is used to reduce tensors from all the
# sub-processes to a specified process, ReduceOp.SUM is used by default.
#
# Because the GPUs may have differing numbers of processed items, we
# perform a weighted mean to calculate the exact loss and accuracy.
########################################################################
dist.reduce(tensor=tensor, dst=dst)
weight = torch.tensor(weight, device=tensor.device)
dist.reduce(tensor=weight, dst=dst)
return tensor / weight


@torch.no_grad()
def evaluate(rank, model, dataloader, num_classes, device):
model.eval()
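
For illustration, a minimal, self-contained sketch (not part of this PR) of the pattern `weighted_reduce` implements: each rank reduces its weighted sum and its weight, and rank `dst` divides the two to get an exact global mean even when ranks processed different numbers of items. It assumes two CPU processes and the gloo backend so it runs without GPUs.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def weighted_reduce(tensor, weight, dst=0):
    # Sum the per-rank values onto `dst` (ReduceOp.SUM is the default).
    dist.reduce(tensor=tensor, dst=dst)
    weight = torch.tensor(weight, device=tensor.device)
    # Sum the per-rank weights onto `dst` as well.
    dist.reduce(tensor=weight, dst=dst)
    # The quotient is only meaningful on rank `dst`.
    return tensor / weight


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Rank 0 saw 3 items with mean loss 1.0; rank 1 saw 1 item with loss 5.0.
    num_items = 3 if rank == 0 else 1
    loss_sum = torch.tensor((1.0 if rank == 0 else 5.0) * num_items)
    avg = weighted_reduce(loss_sum, num_items)
    if rank == 0:
        # Exact mean: (3 * 1.0 + 1 * 5.0) / 4 = 2.0, not the naive
        # per-rank average (1.0 + 5.0) / 2 = 3.0.
        print(f"weighted mean loss: {avg.item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
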
@@ -164,11 +181,10 @@ def evaluate(rank, model, dataloader, num_classes, device):
num_classes=num_classes,
)

return res.to(device)
return res.to(device), sum(y_i.size(0) for y_i in y)


def train(
world_size,
rank,
args,
train_dataloader,
@@ -184,6 +200,7 @@

model.train()
total_loss = torch.tensor(0, dtype=torch.float, device=device)
num_train_items = 0
########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
#
@@ -199,10 +216,8 @@
# uneven inputs.
########################################################################
with Join([model]):
for step, data in (
tqdm.tqdm(enumerate(train_dataloader))
if rank == 0
else enumerate(train_dataloader)
for data in (
tqdm.tqdm(train_dataloader) if rank == 0 else train_dataloader
):
# The input features are from the source nodes in the first
# layer's computation graph.
@@ -223,35 +238,31 @@
loss.backward()
optimizer.step()

total_loss += loss.detach()
total_loss += loss.detach() * y.size(0)
num_train_items += y.size(0)

# Evaluate the model.
if rank == 0:
print("Validating...")
acc = evaluate(
acc, num_val_items = evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
# obtain overall average values.
#
# `torch.distributed.reduce` is used to reduce tensors from all the
# sub-processes to a specified process, ReduceOp.SUM is used by default.
########################################################################
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)

total_loss = weighted_reduce(total_loss, num_train_items)
acc = weighted_reduce(acc * num_val_items, num_val_items)

# We synchronize before measuring the epoch time.
torch.cuda.synchronize()
epoch_end = time.time()
if rank == 0:
print(
f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item() / world_size:.4f} | "
f"Average Loss {total_loss.item():.4f} | "
f"Accuracy {acc.item():.4f} | "
f"Time {epoch_end - epoch_start:.4f}"
)
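
The Join context manager used above deserves a standalone illustration. Below is a minimal sketch (two CPU processes, gloo backend, a toy linear model; all assumptions, not code from this PR) of how Join keeps DDP's collectives matched under uneven inputs: a rank that exhausts its batches early stays inside the context and shadows the gradient allreduces issued by the busier rank, which would otherwise hang.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(torch.nn.Linear(4, 1))
    # Uneven inputs: rank 0 has 5 batches, rank 1 only 3.
    num_batches = 5 if rank == 0 else 3
    with Join([model]):
        for _ in range(num_batches):
            model.zero_grad()
            loss = model(torch.randn(8, 4)).sum()
            # A joined (exhausted) rank shadows this allreduce, so the
            # busy rank's gradient synchronization still completes.
            loss.backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
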

@@ -325,7 +336,6 @@ def run(rank, world_size, args, devices, dataset):
if rank == 0:
print("Training...")
train(
world_size,
rank,
args,
train_dataloader,
@@ -338,18 +348,15 @@
# Test the model.
if rank == 0:
print("Testing...")
test_acc = (
evaluate(
rank,
model,
test_dataloader,
num_classes,
device,
)
/ world_size
test_acc, num_test_items = evaluate(
rank,
model,
test_dataloader,
num_classes,
device,
)
dist.reduce(tensor=test_acc, dst=0)
torch.cuda.synchronize()
test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)

if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}")

@@ -394,6 +401,14 @@ def parse_args():
default=0,
help="The capacity of the GPU cache, the number of features to store.",
)
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
help="The dataset we can use for node classification example. Currently"
" ogbn-products, ogbn-arxiv, ogbn-papers100M datasets are supported.",
)
parser.add_argument(
"--mode",
default="pinned-cuda",
@@ -417,7 +432,7 @@ def parse_args():
print(f"Training with {world_size} gpus.")

# Load and preprocess dataset.
dataset = gb.BuiltinDataset("ogbn-products").load()
dataset = gb.BuiltinDataset(args.dataset).load()

# Thread limiting to avoid resource competition.
os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)
3 changes: 2 additions & 1 deletion examples/sampling/graphbolt/node_classification.py
@@ -365,8 +365,9 @@ def parse_args():
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
help="The dataset we can use for node classification example. Currently"
"dataset ogbn-products, ogbn-arxiv, ogbn-papers100M is supported.",
" ogbn-products, ogbn-arxiv, ogbn-papers100M datasets are supported.",
)
parser.add_argument(
"--mode",
4 changes: 1 addition & 3 deletions tutorials/multi/2_node_classification.py
@@ -186,7 +186,6 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device):


def train(
world_size,
rank,
graph,
features,
@@ -233,7 +232,7 @@ def train(
loss.backward()
optimizer.step()

total_loss += loss * y.size(0)
total_loss += loss.detach() * y.size(0)
num_train_items += y.size(0)

# Evaluate the model.
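
The added .detach() is more than style. A small sketch (illustrative only, not from this PR) of the failure mode it avoids: accumulating the raw loss tensor keeps every iteration's autograd graph reachable, so memory grows with the number of steps.

import torch

model = torch.nn.Linear(4, 1)
total_loss = torch.tensor(0.0)
for _ in range(3):
    y = torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(torch.randn(2, 4)), y)
    loss.backward()
    # Without .detach(), total_loss would require grad and retain each
    # step's intermediate tensors; with it, only the value is kept.
    total_loss += loss.detach() * y.size(0)
print(total_loss.item())
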
@@ -304,7 +303,6 @@ def run(rank, world_size, devices, dataset):
if rank == 0:
print("Training...")
train(
world_size,
rank,
graph,
features,