diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py index aef6b31988a2..f75c72ab005c 100644 --- a/python/dgl/graphbolt/minibatch.py +++ b/python/dgl/graphbolt/minibatch.py @@ -500,7 +500,7 @@ def to_pyg_data(self): col_nodes = torch.cat(col_nodes) row_nodes = torch.cat(row_nodes) edge_index = torch.unique( - torch.stack((col_nodes, row_nodes)), dim=1 + torch.stack((row_nodes, col_nodes)), dim=1 ) if self.node_features is None: diff --git a/tests/python/pytorch/graphbolt/test_minibatch.py b/tests/python/pytorch/graphbolt/test_minibatch.py index 1f708c84c664..7a428fd828d9 100644 --- a/tests/python/pytorch/graphbolt/test_minibatch.py +++ b/tests/python/pytorch/graphbolt/test_minibatch.py @@ -881,7 +881,7 @@ def test_to_pyg_data(): original_column_node_ids=torch.tensor([10, 11]), ) expected_edge_index = torch.tensor( - [[0, 0, 1, 1, 1, 2, 2, 3], [0, 1, 0, 1, 2, 1, 2, 2]] + [[0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 0, 1, 2, 1, 2, 3]] ) expected_node_features = torch.tensor([[1], [2], [3], [4]]) expected_labels = torch.tensor([0, 1])