diff --git a/python/dgl/graphbolt/impl/ondisk_dataset.py b/python/dgl/graphbolt/impl/ondisk_dataset.py
index e636e17e8f31..ca95bcf8f3f1 100644
--- a/python/dgl/graphbolt/impl/ondisk_dataset.py
+++ b/python/dgl/graphbolt/impl/ondisk_dataset.py
@@ -7,14 +7,14 @@
 from copy import deepcopy
 from typing import Dict, List, Union
 
+import numpy as np
+
 import torch
 import yaml
 
-import dgl
-
 from ...base import dgl_warning
 from ...data.utils import download, extract_archive
-from ..base import etype_str_to_tuple
+from ..base import etype_str_to_tuple, ORIGINAL_EDGE_ID
 from ..dataset import Dataset, Task
 from ..internal import (
     calculate_dir_hash,
@@ -26,7 +26,10 @@
 )
 from ..itemset import ItemSet, ItemSetDict
 from ..sampling_graph import SamplingGraph
-from .fused_csc_sampling_graph import from_dglgraph, FusedCSCSamplingGraph
+from .fused_csc_sampling_graph import (
+    fused_csc_sampling_graph,
+    FusedCSCSamplingGraph,
+)
 from .ondisk_metadata import (
     OnDiskGraphTopology,
     OnDiskMetaData,
@@ -38,6 +41,240 @@
 __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
 
 
+def _graph_data_to_fused_csc_sampling_graph(
+    dataset_dir: str,
+    graph_data: Dict,
+    include_original_edge_id: bool,
+) -> FusedCSCSamplingGraph:
+    """Convert the raw graph data into FusedCSCSamplingGraph.
+
+    Parameters
+    ----------
+    dataset_dir : str
+        The path to the dataset directory.
+    graph_data : Dict
+        The raw graph data read from the YAML file.
+    include_original_edge_id : bool
+        Whether to include the original edge IDs in the FusedCSCSamplingGraph.
+
+    Returns
+    -------
+    sampling_graph : FusedCSCSamplingGraph
+        The FusedCSCSamplingGraph constructed from the raw data.
+    """
+    from ...sparse import spmatrix
+
+    is_homogeneous = (
+        len(graph_data["nodes"]) == 1
+        and len(graph_data["edges"]) == 1
+        and "type" not in graph_data["nodes"][0]
+        and "type" not in graph_data["edges"][0]
+    )
+
+    if is_homogeneous:
+        # Homogeneous graph.
+        edge_fmt = graph_data["edges"][0]["format"]
+        edge_path = graph_data["edges"][0]["path"]
+        src, dst = read_edges(dataset_dir, edge_fmt, edge_path)
+        num_nodes = graph_data["nodes"][0]["num"]
+        num_edges = len(src)
+        coo_tensor = torch.tensor(np.array([src, dst]))
+        sparse_matrix = spmatrix(coo_tensor, shape=(num_nodes, num_nodes))
+        del coo_tensor
+        indptr, indices, edge_ids = sparse_matrix.csc()
+        del sparse_matrix
+        node_type_offset = None
+        type_per_edge = None
+        node_type_to_id = None
+        edge_type_to_id = None
+        node_attributes = {}
+        edge_attributes = {}
+        if include_original_edge_id:
+            edge_attributes[ORIGINAL_EDGE_ID] = edge_ids
+    else:
+        # Heterogeneous graph.
+        # Sort graph_data by ntype/etype lexicographically to ensure ordering.
+        graph_data["nodes"].sort(key=lambda x: x["type"])
+        graph_data["edges"].sort(key=lambda x: x["type"])
+        # Construct node_type_offset and node_type_to_id.
+        node_type_offset = [0]
+        node_type_to_id = {}
+        for ntype_id, node_info in enumerate(graph_data["nodes"]):
+            node_type_to_id[node_info["type"]] = ntype_id
+            node_type_offset.append(node_type_offset[-1] + node_info["num"])
+        total_num_nodes = node_type_offset[-1]
+        # Construct edge_type_offset, edge_type_to_id and coo_tensor.
+        edge_type_offset = [0]
+        edge_type_to_id = {}
+        coo_src_list = []
+        coo_dst_list = []
+        coo_etype_list = []
+        for etype_id, edge_info in enumerate(graph_data["edges"]):
+            edge_type_to_id[edge_info["type"]] = etype_id
+            edge_fmt = edge_info["format"]
+            edge_path = edge_info["path"]
+            src, dst = read_edges(dataset_dir, edge_fmt, edge_path)
+            edge_type_offset.append(edge_type_offset[-1] + len(src))
+            src_type, _, dst_type = etype_str_to_tuple(edge_info["type"])
+            src += node_type_offset[node_type_to_id[src_type]]
+            dst += node_type_offset[node_type_to_id[dst_type]]
+            coo_src_list.append(torch.tensor(src))
+            coo_dst_list.append(torch.tensor(dst))
+            coo_etype_list.append(torch.full((len(src),), etype_id))
+        total_num_edges = edge_type_offset[-1]
+
+        coo_src = torch.cat(coo_src_list)
+        del coo_src_list
+        coo_dst = torch.cat(coo_dst_list)
+        del coo_dst_list
+        coo_etype = torch.cat(coo_etype_list)
+        del coo_etype_list
+
+        sparse_matrix = spmatrix(
+            indices=torch.stack((coo_src, coo_dst), dim=0),
+            shape=(total_num_nodes, total_num_nodes),
+        )
+        del coo_src, coo_dst
+        indptr, indices, edge_ids = sparse_matrix.csc()
+        del sparse_matrix
+        node_type_offset = torch.tensor(node_type_offset)
+        type_per_edge = torch.index_select(coo_etype, dim=0, index=edge_ids)
+        del coo_etype
+        node_attributes = {}
+        edge_attributes = {}
+        if include_original_edge_id:
+            edge_ids -= torch.gather(
+                input=torch.tensor(edge_type_offset),
+                dim=0,
+                index=type_per_edge,
+            )
+            edge_attributes[ORIGINAL_EDGE_ID] = edge_ids
+
+    # Load the sampling-related node/edge features and add them to
+    # the sampling graph.
+    if graph_data.get("feature_data", None):
+        if is_homogeneous:
+            # Homogeneous graph.
+            for graph_feature in graph_data["feature_data"]:
+                in_memory = (
+                    True
+                    if "in_memory" not in graph_feature
+                    else graph_feature["in_memory"]
+                )
+                if graph_feature["domain"] == "node":
+                    node_data = read_data(
+                        os.path.join(dataset_dir, graph_feature["path"]),
+                        graph_feature["format"],
+                        in_memory=in_memory,
+                    )
+                    assert node_data.shape[0] == num_nodes
+                    node_attributes[graph_feature["name"]] = node_data
+                elif graph_feature["domain"] == "edge":
+                    edge_data = read_data(
+                        os.path.join(dataset_dir, graph_feature["path"]),
+                        graph_feature["format"],
+                        in_memory=in_memory,
+                    )
+                    assert edge_data.shape[0] == num_edges
+                    edge_attributes[graph_feature["name"]] = edge_data
+        else:
+            # Heterogeneous graph.
+            node_feature_collector = {}
+            edge_feature_collector = {}
+            for graph_feature in graph_data["feature_data"]:
+                in_memory = (
+                    True
+                    if "in_memory" not in graph_feature
+                    else graph_feature["in_memory"]
+                )
+                if graph_feature["domain"] == "node":
+                    node_data = read_data(
+                        os.path.join(dataset_dir, graph_feature["path"]),
+                        graph_feature["format"],
+                        in_memory=in_memory,
+                    )
+                    if graph_feature["name"] not in node_feature_collector:
+                        node_feature_collector[graph_feature["name"]] = {}
+                    node_feature_collector[graph_feature["name"]][
+                        graph_feature["type"]
+                    ] = node_data
+                elif graph_feature["domain"] == "edge":
+                    edge_data = read_data(
+                        os.path.join(dataset_dir, graph_feature["path"]),
+                        graph_feature["format"],
+                        in_memory=in_memory,
+                    )
+                    if graph_feature["name"] not in edge_feature_collector:
+                        edge_feature_collector[graph_feature["name"]] = {}
+                    edge_feature_collector[graph_feature["name"]][
+                        graph_feature["type"]
+                    ] = edge_data
+            # For heterogeneous graphs, a node/edge feature must cover all node/edge types.
+            all_node_types = set(node_type_to_id.keys())
+            for feat_name, feat_data in node_feature_collector.items():
+                existing_node_type = set(feat_data.keys())
+                assert all_node_types == existing_node_type, (
+                    f"Node feature {feat_name} does not cover all node types. "
+                    f"Existing types: {existing_node_type}. "
+                    f"Expected types: {all_node_types}."
+                )
+            all_edge_types = set(edge_type_to_id.keys())
+            for feat_name, feat_data in edge_feature_collector.items():
+                existing_edge_type = set(feat_data.keys())
+                assert all_edge_types == existing_edge_type, (
+                    f"Edge feature {feat_name} does not cover all edge types. "
+                    f"Existing types: {existing_edge_type}. "
+                    f"Expected types: {all_edge_types}."
+                )
+
+            for feat_name, feat_data in node_feature_collector.items():
+                _feat = next(iter(feat_data.values()))
+                feat_tensor = torch.empty(
+                    ([total_num_nodes] + list(_feat.shape[1:])),
+                    dtype=_feat.dtype,
+                )
+                for ntype, feat in feat_data.items():
+                    feat_tensor[
+                        node_type_offset[
+                            node_type_to_id[ntype]
+                        ] : node_type_offset[node_type_to_id[ntype] + 1]
+                    ] = feat
+                node_attributes[feat_name] = feat_tensor
+            del node_feature_collector
+            for feat_name, feat_data in edge_feature_collector.items():
+                _feat = next(iter(feat_data.values()))
+                feat_tensor = torch.empty(
+                    ([total_num_edges] + list(_feat.shape[1:])),
+                    dtype=_feat.dtype,
+                )
+                for etype, feat in feat_data.items():
+                    feat_tensor[
+                        edge_type_offset[
+                            edge_type_to_id[etype]
+                        ] : edge_type_offset[edge_type_to_id[etype] + 1]
+                    ] = feat
+                edge_attributes[feat_name] = feat_tensor
+            del edge_feature_collector
+
+    if not bool(node_attributes):
+        node_attributes = None
+    if not bool(edge_attributes):
+        edge_attributes = None
+
+    # Construct the FusedCSCSamplingGraph.
+    return fused_csc_sampling_graph(
+        csc_indptr=indptr,
+        indices=indices,
+        node_type_offset=node_type_offset,
+        type_per_edge=type_per_edge,
+        node_type_to_id=node_type_to_id,
+        edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
+    )
+
+
 def preprocess_ondisk_dataset(
     dataset_dir: str,
     include_original_edge_id: bool = False,
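Note on the helper above: it packs all node types into one contiguous ID space via node_type_offset, builds a single COO matrix, and lets dgl.sparse produce the CSC form. The third tensor returned by csc() is the permutation mapping each CSC position back to its COO position, which the patch uses as edge_ids to keep per-edge metadata aligned. A minimal standalone sketch of that round trip, with made-up toy sizes (not part of the patch):

    import torch
    from dgl.sparse import spmatrix

    # Two node types packed into one ID space: "user" -> [0, 2), "item" -> [2, 5).
    node_type_offset = [0, 2, 5]

    # Three user->item edges with per-type IDs, shifted by the type offsets.
    src = torch.tensor([0, 1, 1]) + node_type_offset[0]
    dst = torch.tensor([0, 0, 2]) + node_type_offset[1]

    matrix = spmatrix(torch.stack((src, dst)), shape=(5, 5))
    # edge_ids[i] is the COO position of the i-th CSC entry; gathering
    # per-edge metadata (edge type, original ID) with it keeps everything
    # aligned after the conversion.
    indptr, indices, edge_ids = matrix.csc()

    # With a single edge type the localization below is a no-op, but it is
    # how the patch recovers per-type original edge IDs.
    edge_type_offset = torch.tensor([0, 3])
    type_per_edge = torch.zeros_like(edge_ids)
    edge_ids = edge_ids - torch.gather(edge_type_offset, 0, type_per_edge)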
@@ -115,108 +352,20 @@ def preprocess_ondisk_dataset(
     os.makedirs(os.path.join(dataset_dir, processed_dir_prefix), exist_ok=True)
     output_config = deepcopy(input_config)
 
-    # 2. Load the edge data and create a DGLGraph.
+    # 2. Load the data and create a FusedCSCSamplingGraph.
     if "graph" not in input_config:
         raise RuntimeError("Invalid config: does not contain graph field.")
 
-    # For any graph that node/edge types are specified, we construct DGLGraph
-    # with `dgl.heterograph()` even there's only one node/edge type. This is
-    # because we want to save the node/edge types in the graph. So the logic of
-    # checking whether the graph is homogeneous is different from the logic in
-    # `DGLGraph.is_homogeneous()`. Otherwise, we construct DGLGraph with
-    # `dgl.graph()`.
-    is_homogeneous = (
-        len(input_config["graph"]["nodes"]) == 1
-        and len(input_config["graph"]["edges"]) == 1
-        and "type" not in input_config["graph"]["nodes"][0]
-        and "type" not in input_config["graph"]["edges"][0]
-    )
-    if is_homogeneous:
-        # Homogeneous graph.
-        num_nodes = input_config["graph"]["nodes"][0]["num"]
-        edge_fmt = input_config["graph"]["edges"][0]["format"]
-        edge_path = input_config["graph"]["edges"][0]["path"]
-        src, dst = read_edges(dataset_dir, edge_fmt, edge_path)
-        g = dgl.graph((src, dst), num_nodes=num_nodes)
-    else:
-        # Heterogeneous graph.
-        # Construct the num nodes dict.
-        num_nodes_dict = {}
-        for node_info in input_config["graph"]["nodes"]:
-            num_nodes_dict[node_info["type"]] = node_info["num"]
-        # Construct the data dict.
-        data_dict = {}
-        for edge_info in input_config["graph"]["edges"]:
-            edge_fmt = edge_info["format"]
-            edge_path = edge_info["path"]
-            src, dst = read_edges(dataset_dir, edge_fmt, edge_path)
-            data_dict[etype_str_to_tuple(edge_info["type"])] = (src, dst)
-        # Construct the heterograph.
-        g = dgl.heterograph(data_dict, num_nodes_dict)
-
-    # 3. Load the sampling related node/edge features and add them to
-    # the sampling-graph.
-    if input_config["graph"].get("feature_data", None):
-        for graph_feature in input_config["graph"]["feature_data"]:
-            in_memory = (
-                True
-                if "in_memory" not in graph_feature
-                else graph_feature["in_memory"]
-            )
-            if graph_feature["domain"] == "node":
-                node_data = read_data(
-                    os.path.join(dataset_dir, graph_feature["path"]),
-                    graph_feature["format"],
-                    in_memory=in_memory,
-                )
-                if is_homogeneous:
-                    g.ndata[graph_feature["name"]] = node_data
-                else:
-                    g.nodes[graph_feature["type"]].data[
-                        graph_feature["name"]
-                    ] = node_data
-            if graph_feature["domain"] == "edge":
-                edge_data = read_data(
-                    os.path.join(dataset_dir, graph_feature["path"]),
-                    graph_feature["format"],
-                    in_memory=in_memory,
-                )
-                if is_homogeneous:
-                    g.edata[graph_feature["name"]] = edge_data
-                else:
-                    g.edges[etype_str_to_tuple(graph_feature["type"])].data[
-                        graph_feature["name"]
-                    ] = edge_data
-    if not is_homogeneous:
-        # For heterogenous graph, a node/edge feature must cover all
-        # node/edge types.
-        ntypes = g.ntypes
-        assert all(
-            set(g.nodes[ntypes[0]].data.keys())
-            == set(g.nodes[ntype].data.keys())
-            for ntype in ntypes
-        ), (
-            "Node feature does not cover all node types: "
-            + f"{set(g.nodes[ntype].data.keys() for ntype in ntypes)}."
-        )
-        etypes = g.canonical_etypes
-        assert all(
-            set(g.edges[etypes[0]].data.keys())
-            == set(g.edges[etype].data.keys())
-            for etype in etypes
-        ), (
-            "Edge feature does not cover all edge types: "
-            + f"{set(g.edges[etype].data.keys() for etype in etypes)}."
-        )
-    # 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
-    fused_csc_sampling_graph = from_dglgraph(
-        g, is_homogeneous, include_original_edge_id
+    sampling_graph = _graph_data_to_fused_csc_sampling_graph(
+        dataset_dir,
+        input_config["graph"],
+        include_original_edge_id,
     )
 
-    # 5. Record value of include_original_edge_id.
+    # 3. Record value of include_original_edge_id.
     output_config["include_original_edge_id"] = include_original_edge_id
 
-    # 6. Save the FusedCSCSamplingGraph and modify the output_config.
+    # 4. Save the FusedCSCSamplingGraph and modify the output_config.
     output_config["graph_topology"] = {}
     output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
     output_config["graph_topology"]["path"] = os.path.join(
@@ -224,7 +373,7 @@ def preprocess_ondisk_dataset(
     )
 
     torch.save(
-        fused_csc_sampling_graph,
+        sampling_graph,
         os.path.join(
             dataset_dir,
             output_config["graph_topology"]["path"],
@@ -232,7 +381,7 @@
     )
     del output_config["graph"]
 
-    # 7. Load the node/edge features and do necessary conversion.
+    # 5. Load the node/edge features and do necessary conversion.
     if input_config.get("feature_data", None):
         has_edge_feature_data = False
         for feature, out_feature in zip(
@@ -259,7 +408,7 @@
     if has_edge_feature_data and not include_original_edge_id:
         dgl_warning("Edge feature is stored, but edge IDs are not saved.")
 
-    # 8. Save tasks and train/val/test split according to the output_config.
+    # 6. Save tasks and train/val/test split according to the output_config.
     if input_config.get("tasks", None):
         for input_task, output_task in zip(
             input_config["tasks"], output_config["tasks"]
@@ -286,13 +435,13 @@
                     output_data["format"],
                 )
 
-    # 9. Save the output_config.
+    # 7. Save the output_config.
     output_config_path = os.path.join(dataset_dir, preprocess_metadata_path)
     with open(output_config_path, "w") as f:
         yaml.dump(output_config, f)
     print("Finish preprocessing the on-disk dataset.")
 
-    # 10. Calculate and save the hash value of the dataset directory.
+    # 8. Calculate and save the hash value of the dataset directory.
     hash_value_file = "dataset_hash_value.txt"
     hash_value_file_path = os.path.join(
         dataset_dir, processed_dir_prefix, hash_value_file
@@ -303,7 +452,7 @@
     with open(hash_value_file_path, "w") as f:
         f.write(json.dumps(dir_hash, indent=4))
 
-    # 11. Return the absolute path of the preprocessing yaml file.
+    # 9. Return the absolute path of the preprocessing yaml file.
     return output_config_path
 
 
diff --git a/tests/python/pytorch/graphbolt/gb_test_utils.py b/tests/python/pytorch/graphbolt/gb_test_utils.py
index 59c4c3a90276..005d99b2cba3 100644
--- a/tests/python/pytorch/graphbolt/gb_test_utils.py
+++ b/tests/python/pytorch/graphbolt/gb_test_utils.py
@@ -92,8 +92,10 @@ def random_homo_graphbolt_graph(
 ):
     """Generate random graphbolt version homograph"""
     # Generate random edges.
-    nodes = np.repeat(np.arange(num_nodes), 5)
-    neighbors = np.random.randint(0, num_nodes, size=(num_edges))
+    nodes = np.repeat(np.arange(num_nodes, dtype=np.int64), 5)
+    neighbors = np.random.randint(
+        0, num_nodes, size=(num_edges), dtype=np.int64
+    )
     edges = np.stack([nodes, neighbors], axis=1)
     os.makedirs(os.path.join(test_dir, "edges"), exist_ok=True)
     assert edge_fmt in ["numpy", "csv"], print(
@@ -101,9 +103,9 @@
     )
     if edge_fmt == "csv":
         # Wrtie into edges/edge.csv
-        edges = pd.DataFrame(edges, columns=["src", "dst"])
+        edges_DataFrame = pd.DataFrame(edges, columns=["src", "dst"])
         edge_path = os.path.join("edges", "edge.csv")
-        edges.to_csv(
+        edges_DataFrame.to_csv(
             os.path.join(test_dir, edge_path),
             index=False,
             header=False,
@@ -136,7 +138,7 @@
         np.arange(each_set_size),
         np.arange(each_set_size, 2 * each_set_size),
     )
-    train_data = np.vstack(train_pairs).T.astype(np.int64)
+    train_data = np.vstack(train_pairs).T.astype(edges.dtype)
     train_path = os.path.join("set", "train.npy")
     np.save(os.path.join(test_dir, train_path), train_data)
 
@@ -144,7 +146,7 @@
         np.arange(each_set_size, 2 * each_set_size),
         np.arange(2 * each_set_size, 3 * each_set_size),
     )
-    validation_data = np.vstack(validation_pairs).T.astype(np.int64)
+    validation_data = np.vstack(validation_pairs).T.astype(edges.dtype)
     validation_path = os.path.join("set", "validation.npy")
     np.save(os.path.join(test_dir, validation_path), validation_data)
 
@@ -152,7 +154,7 @@
         np.arange(2 * each_set_size, 3 * each_set_size),
         np.arange(3 * each_set_size, 4 * each_set_size),
     )
-    test_data = np.vstack(test_pairs).T.astype(np.int64)
+    test_data = np.vstack(test_pairs).T.astype(edges.dtype)
     test_path = os.path.join("set", "test.npy")
     np.save(os.path.join(test_dir, test_path), test_data)
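Note on the test-utility changes above: the edge arrays are pinned to int64 and the train/validation/test splits now take their dtype from edges.dtype instead of hard-coding np.int64. NumPy's default integer follows the platform's C long (int32 on Windows), and the new preprocessing path no longer round-trips through a DGLGraph, which normalized indices to int64 by default, so the dtype has to be fixed at generation time. A standalone illustration of the pattern (not part of the patch):

    import numpy as np

    num_nodes, each_set_size = 5, 3
    nodes = np.repeat(np.arange(num_nodes, dtype=np.int64), 5)
    neighbors = np.random.randint(
        0, num_nodes, size=(len(nodes),), dtype=np.int64
    )
    edges = np.stack([nodes, neighbors], axis=1)

    # Deriving the split dtype from the edges keeps the two consistent
    # even if the edge dtype changes later.
    train_pairs = (
        np.arange(each_set_size),
        np.arange(each_set_size, 2 * each_set_size),
    )
    train_data = np.vstack(train_pairs).T.astype(edges.dtype)
    assert train_data.dtype == edges.dtype == np.int64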
diff --git a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
index c1e02b0efca3..f5459b316517 100644
--- a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+++ b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
@@ -1211,7 +1211,8 @@ def test_OnDiskDataset_preprocess_homogeneous_hardcode(edge_fmt="numpy"):
 
     # Generate edges.
     edges = np.array(
-        [[0, 0, 1, 1, 2, 2, 3, 3, 4, 4], [1, 2, 2, 3, 3, 4, 4, 0, 0, 1]]
+        [[0, 0, 1, 1, 2, 2, 3, 3, 4, 4], [1, 2, 2, 3, 3, 4, 4, 0, 0, 1]],
+        dtype=np.int64,
     ).T
     os.makedirs(os.path.join(test_dir, "edges"), exist_ok=True)
     edges = edges.T
@@ -1220,14 +1221,18 @@
 
     # Generate graph edge-feats.
     edge_feats = np.array(
-        [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
+        [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9],
+        dtype=np.float64,
     )
     os.makedirs(os.path.join(test_dir, "data"), exist_ok=True)
     edge_feat_path = os.path.join("data", "edge-feat.npy")
     np.save(os.path.join(test_dir, edge_feat_path), edge_feats)
 
     # Generate node-feats.
-    node_feats = np.array([0.0, 1.9, 2.8, 3.7, 4.6])
+    node_feats = np.array(
+        [0.0, 1.9, 2.8, 3.7, 4.6],
+        dtype=np.float64,
+    )
     node_feat_path = os.path.join("data", "node-feat.npy")
     np.save(os.path.join(test_dir, node_feat_path), node_feats)
 
@@ -1391,45 +1396,50 @@ def test_OnDiskDataset_preprocess_heterogeneous_hardcode(edge_fmt="numpy"):
     # Generate edges.
     os.makedirs(os.path.join(test_dir, "edges"), exist_ok=True)
     np.save(
-        os.path.join(test_dir, "edges", "a_a.npy"), np.array([[0], [1]])
+        os.path.join(test_dir, "edges", "a_a.npy"),
+        np.array([[0], [1]], dtype=np.int64),
     )
     np.save(
         os.path.join(test_dir, "edges", "a_b.npy"),
-        np.array([[0, 1, 1], [0, 0, 1]]),
+        np.array([[0, 1, 1], [0, 0, 1]], dtype=np.int64),
     )
     np.save(
         os.path.join(test_dir, "edges", "b_b.npy"),
-        np.array([[0, 0, 1], [1, 2, 2]]),
+        np.array([[0, 0, 1], [1, 2, 2]], dtype=np.int64),
     )
     np.save(
         os.path.join(test_dir, "edges", "b_a.npy"),
-        np.array([[1, 2, 2], [0, 0, 1]]),
+        np.array([[1, 2, 2], [0, 0, 1]], dtype=np.int64),
     )
 
     # Generate node features.
     os.makedirs(os.path.join(test_dir, "data"), exist_ok=True)
     np.save(
-        os.path.join(test_dir, "data", "A-feat.npy"), np.array([0.0, 1.9])
+        os.path.join(test_dir, "data", "A-feat.npy"),
+        np.array([0.0, 1.9], dtype=np.float64),
    )
     np.save(
         os.path.join(test_dir, "data", "B-feat.npy"),
-        np.array([2.8, 3.7, 4.6]),
+        np.array([2.8, 3.7, 4.6], dtype=np.float64),
     )
 
     # Generate edge features.
     os.makedirs(os.path.join(test_dir, "data"), exist_ok=True)
-    np.save(os.path.join(test_dir, "data", "a_a-feat.npy"), np.array([0.0]))
+    np.save(
+        os.path.join(test_dir, "data", "a_a-feat.npy"),
+        np.array([0.0], dtype=np.float64),
+    )
     np.save(
         os.path.join(test_dir, "data", "a_b-feat.npy"),
-        np.array([1.1, 2.2, 3.3]),
+        np.array([1.1, 2.2, 3.3], dtype=np.float64),
     )
     np.save(
         os.path.join(test_dir, "data", "b_b-feat.npy"),
-        np.array([4.4, 5.5, 6.6]),
+        np.array([4.4, 5.5, 6.6], dtype=np.float64),
     )
     np.save(
         os.path.join(test_dir, "data", "b_a-feat.npy"),
-        np.array([7.7, 8.8, 9.9]),
+        np.array([7.7, 8.8, 9.9], dtype=np.float64),
     )
 
     yaml_content = (