From 3fb81d7d9b38ff9872094af1ddb0a3ba293747ec Mon Sep 17 00:00:00 2001 From: Skeleton003 <799284168@qq.com> Date: Sat, 17 Feb 2024 13:32:50 +0000 Subject: [PATCH] 1 --- .../python/pytorch/graphbolt/gb_test_utils.py | 24 ++++++++++--------- .../graphbolt/impl/test_ondisk_dataset.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/python/pytorch/graphbolt/gb_test_utils.py b/tests/python/pytorch/graphbolt/gb_test_utils.py index 59c4c3a90276..d4ef8ff14401 100644 --- a/tests/python/pytorch/graphbolt/gb_test_utils.py +++ b/tests/python/pytorch/graphbolt/gb_test_utils.py @@ -96,11 +96,12 @@ def random_homo_graphbolt_graph( neighbors = np.random.randint(0, num_nodes, size=(num_edges)) edges = np.stack([nodes, neighbors], axis=1) os.makedirs(os.path.join(test_dir, "edges"), exist_ok=True) - assert edge_fmt in ["numpy", "csv"], print( - "only numpy and csv are supported for edges." - ) + assert edge_fmt in [ + "numpy", + "csv", + ], "Only numpy and csv are supported for edges." if edge_fmt == "csv": - # Wrtie into edges/edge.csv + # Write into edges/edge.csv edges = pd.DataFrame(edges, columns=["src", "dst"]) edge_path = os.path.join("edges", "edge.csv") edges.to_csv( @@ -109,7 +110,7 @@ def random_homo_graphbolt_graph( header=False, ) else: - # Wrtie into edges/edge.npy + # Write into edges/edge.npy edges = edges.T edge_path = os.path.join("edges", "edge.npy") np.save(os.path.join(test_dir, edge_path), edges) @@ -158,7 +159,7 @@ def random_homo_graphbolt_graph( yaml_content = f""" dataset_name: {dataset_name} - graph: # graph structure and required attributes. + graph: # Graph structure and required attributes. nodes: - num: {num_nodes} edges: @@ -217,7 +218,7 @@ def random_homo_graphbolt_graph( return yaml_content -def genereate_raw_data_for_hetero_dataset( +def generate_raw_data_for_hetero_dataset( test_dir, dataset_name, num_nodes, num_edges, num_classes, edge_fmt="csv" ): # Generate edges. @@ -227,9 +228,10 @@ def genereate_raw_data_for_hetero_dataset( src = torch.randint(0, num_nodes[src_ntype], (num_edge,)) dst = torch.randint(0, num_nodes[dst_ntype], (num_edge,)) os.makedirs(os.path.join(test_dir, "edges"), exist_ok=True) - assert edge_fmt in ["numpy", "csv"], print( - "only numpy and csv are supported for edges." - ) + assert edge_fmt in [ + "numpy", + "csv", + ], "Only numpy and csv are supported for edges." if edge_fmt == "csv": # Write into edges/edge.csv edges = pd.DataFrame( @@ -288,7 +290,7 @@ def genereate_raw_data_for_hetero_dataset( yaml_content = f""" dataset_name: {dataset_name} - graph: # graph structure and required attributes. + graph: # Graph structure and required attributes. nodes: - type: user num: {num_nodes["user"]} diff --git a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py index c1e02b0efca3..876479747b01 100644 --- a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py +++ b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py @@ -2703,7 +2703,7 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt): ("user", "click", "item"): 20000, } num_classes = 10 - gbt.genereate_raw_data_for_hetero_dataset( + gbt.generate_raw_data_for_hetero_dataset( test_dir, dataset_name, num_nodes,