diff --git a/examples/pytorch/hgp_sl/functions.py b/examples/pytorch/hgp_sl/functions.py
index 3c22261f2e6b..c12d72d25a8f 100644
--- a/examples/pytorch/hgp_sl/functions.py
+++ b/examples/pytorch/hgp_sl/functions.py
@@ -9,10 +9,10 @@
 """
 import dgl
 import torch
+from dgl._sparse_ops import _gsddmm, _gspmm
 from dgl.backend import astype
 from dgl.base import ALL, is_all
 from dgl.heterograph_index import HeteroGraphIndex
-from dgl.sparse import _gsddmm, _gspmm
 from torch import Tensor
 from torch.autograd import Function
 
diff --git a/tests/python/pytorch/graphbolt/test_subgraph_sampler.py b/tests/python/pytorch/graphbolt/test_subgraph_sampler.py
index a5c8ef53c305..d8b80381cc0b 100644
--- a/tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+++ b/tests/python/pytorch/graphbolt/test_subgraph_sampler.py
@@ -1,4 +1,5 @@
 import unittest
+import warnings
 from enum import Enum
 from functools import partial
 
@@ -9,7 +10,6 @@
 import dgl.graphbolt as gb
 import pytest
 import torch
-from torchdata.datapipes.iter import Mapper
 
 from . import gb_test_utils
 
@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type):
     )
 
 
+def _check_sampler_len(sampler, lenExp):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        assert len(list(sampler)) == lenExp
+
+
 class SamplerType(Enum):
     Normal = 0
     Layer = 1
@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type):
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     sampler = _get_sampler(sampler_type)
     sampler_dp = sampler(item_sampler, graph, fanouts)
-    assert len(list(sampler_dp)) == 5
+    _check_sampler_len(sampler_dp, 5)
 
 
 def to_link_batch(data):
@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type):
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 @pytest.mark.parametrize(
@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type):
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 def get_hetero_graph():
@@ -239,9 +245,11 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type):
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     sampler = _get_sampler(sampler_type)
     sampler_dp = sampler(item_sampler, graph, fanouts)
-    assert len(list(sampler_dp)) == 2
-    for minibatch in sampler_dp:
-        assert len(minibatch.sampled_subgraphs) == num_layer
+    _check_sampler_len(sampler_dp, 2)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        for minibatch in sampler_dp:
+            assert len(minibatch.sampled_subgraphs) == num_layer
 
 
 @pytest.mark.parametrize(
@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type):
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 @pytest.mark.parametrize(
@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type):
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 @pytest.mark.parametrize(
@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type):
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 @pytest.mark.parametrize(
@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs(
     sampler = _get_sampler(sampler_type)
     datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    assert len(list(datapipe)) == 5
+    _check_sampler_len(datapipe, 5)
 
 
 @pytest.mark.parametrize(
@@ -493,32 +501,28 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
 
     sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)
 
-    for data in sampler_dp:
-        for sampledsubgraph in data.sampled_subgraphs:
-            for _, value in sampledsubgraph.sampled_csc.items():
-                assert torch.equal(
-                    torch.ge(
-                        value.indices,
-                        torch.zeros(len(value.indices)).to(F.ctx()),
-                    ),
-                    torch.ones(len(value.indices)).to(F.ctx()),
-                )
-                assert torch.equal(
-                    torch.ge(
-                        value.indptr, torch.zeros(len(value.indptr)).to(F.ctx())
-                    ),
-                    torch.ones(len(value.indptr)).to(F.ctx()),
-                )
-            for _, value in sampledsubgraph.original_column_node_ids.items():
-                assert torch.equal(
-                    torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
-                    torch.ones(len(value)).to(F.ctx()),
-                )
-            for _, value in sampledsubgraph.original_row_node_ids.items():
-                assert torch.equal(
-                    torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
-                    torch.ones(len(value)).to(F.ctx()),
-                )
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        for data in sampler_dp:
+            for sampledsubgraph in data.sampled_subgraphs:
+                for _, value in sampledsubgraph.sampled_csc.items():
+                    for idx in [value.indices, value.indptr]:
+                        assert torch.equal(
+                            torch.ge(idx, torch.zeros(len(idx)).to(F.ctx())),
+                            torch.ones(len(idx)).to(F.ctx()),
+                        )
+                node_ids = [
+                    sampledsubgraph.original_column_node_ids,
+                    sampledsubgraph.original_row_node_ids,
+                ]
+                for ids in node_ids:
+                    for _, value in ids.items():
+                        assert torch.equal(
+                            torch.ge(
+                                value, torch.zeros(len(value)).to(F.ctx())
+                            ),
+                            torch.ones(len(value)).to(F.ctx()),
+                        )
 
 
 @pytest.mark.parametrize(
@@ -570,9 +574,60 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
         torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
         torch.tensor([0, 3, 4]).to(F.ctx()),
     ]
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        for data in datapipe:
+            for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
+                assert (
+                    len(sampled_subgraph.original_row_node_ids) == length[step]
+                )
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc.indices,
+                    compacted_indices[step],
+                )
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc.indptr, indptr[step]
+                )
+                assert torch.equal(
+                    torch.sort(sampled_subgraph.original_column_node_ids)[0],
+                    seeds[step],
+                )
+
+
+def _assert_hetero_values(
+    datapipe, original_row_node_ids, original_column_node_ids, csc_formats
+):
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
-            assert len(sampled_subgraph.original_row_node_ids) == length[step]
+            for ntype in ["n1", "n2"]:
+                assert torch.equal(
+                    sampled_subgraph.original_row_node_ids[ntype],
+                    original_row_node_ids[step][ntype].to(F.ctx()),
+                )
+                assert torch.equal(
+                    sampled_subgraph.original_column_node_ids[ntype],
+                    original_column_node_ids[step][ntype].to(F.ctx()),
+                )
+            for etype in ["n1:e1:n2", "n2:e2:n1"]:
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc[etype].indices,
+                    csc_formats[step][etype].indices.to(F.ctx()),
+                )
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc[etype].indptr,
+                    csc_formats[step][etype].indptr.to(F.ctx()),
+                )
+
+
+def _assert_homo_values(
+    datapipe, original_row_node_ids, compacted_indices, indptr, seeds
+):
+    for data in datapipe:
+        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
+            assert torch.equal(
+                sampled_subgraph.original_row_node_ids,
+                original_row_node_ids[step],
+            )
             assert torch.equal(
                 sampled_subgraph.sampled_csc.indices, compacted_indices[step]
             )
@@ -580,8 +635,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
                 sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
             assert torch.equal(
-                torch.sort(sampled_subgraph.original_column_node_ids)[0],
-                seeds[step],
+                sampled_subgraph.original_column_node_ids, seeds[step]
             )
 
 
@@ -655,26 +709,14 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
         },
     ]
 
-    for data in datapipe:
-        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
-            for ntype in ["n1", "n2"]:
-                assert torch.equal(
-                    sampled_subgraph.original_row_node_ids[ntype],
-                    original_row_node_ids[step][ntype].to(F.ctx()),
-                )
-                assert torch.equal(
-                    sampled_subgraph.original_column_node_ids[ntype],
-                    original_column_node_ids[step][ntype].to(F.ctx()),
-                )
-            for etype in ["n1:e1:n2", "n2:e2:n1"]:
-                assert torch.equal(
-                    sampled_subgraph.sampled_csc[etype].indices,
-                    csc_formats[step][etype].indices.to(F.ctx()),
-                )
-                assert torch.equal(
-                    sampled_subgraph.sampled_csc[etype].indptr,
-                    csc_formats[step][etype].indptr.to(F.ctx()),
-                )
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        _assert_hetero_values(
+            datapipe,
+            original_row_node_ids,
+            original_column_node_ids,
+            csc_formats,
+        )
 
 
 @unittest.skipIf(
@@ -719,21 +761,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor):
         torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
         torch.tensor([0, 3, 4]).to(F.ctx()),
     ]
-    for data in datapipe:
-        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
-            assert torch.equal(
-                sampled_subgraph.original_row_node_ids,
-                original_row_node_ids[step],
-            )
-            assert torch.equal(
-                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
-            )
-            assert torch.equal(
-                sampled_subgraph.sampled_csc.indptr, indptr[step]
-            )
-            assert torch.equal(
-                sampled_subgraph.original_column_node_ids, seeds[step]
-            )
+    _assert_homo_values(
+        datapipe, original_row_node_ids, compacted_indices, indptr, seeds
+    )
 
 
 @unittest.skipIf(
@@ -778,21 +808,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
         torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
     ]
-    for data in datapipe:
-        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
-            assert torch.equal(
-                sampled_subgraph.original_row_node_ids,
-                original_row_node_ids[step],
-            )
-            assert torch.equal(
-                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
-            )
-            assert torch.equal(
-                sampled_subgraph.sampled_csc.indptr, indptr[step]
-            )
-            assert torch.equal(
-                sampled_subgraph.original_column_node_ids, seeds[step]
-            )
+    _assert_homo_values(
+        datapipe, original_row_node_ids, compacted_indices, indptr, seeds
+    )
 
 
 @pytest.mark.parametrize("labor", [False, True])
@@ -853,27 +871,9 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor):
             "n2": torch.tensor([0, 1]),
         },
     ]
-
-    for data in datapipe:
-        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
-            for ntype in ["n1", "n2"]:
-                assert torch.equal(
-                    sampled_subgraph.original_row_node_ids[ntype],
-                    original_row_node_ids[step][ntype].to(F.ctx()),
-                )
-                assert torch.equal(
-                    sampled_subgraph.original_column_node_ids[ntype],
-                    original_column_node_ids[step][ntype].to(F.ctx()),
-                )
-            for etype in ["n1:e1:n2", "n2:e2:n1"]:
-                assert torch.equal(
-                    sampled_subgraph.sampled_csc[etype].indices,
-                    csc_formats[step][etype].indices.to(F.ctx()),
-                )
-                assert torch.equal(
-                    sampled_subgraph.sampled_csc[etype].indptr,
-                    csc_formats[step][etype].indptr.to(F.ctx()),
-                )
+    _assert_hetero_values(
+        datapipe, original_row_node_ids, original_column_node_ids, csc_formats
+    )
 
 
 @pytest.mark.parametrize(
@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
     items_n1 = torch.tensor([0])
     items_n2 = torch.tensor([1])
     names = "seed_nodes"
+    item_length = 2
     if sampler_type == SamplerType.Temporal:
+        item_length = 3
         graph.node_attributes = {
             "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
         }
@@ -909,38 +911,31 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
     fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]
     sampler = _get_sampler(sampler_type)
     sampler_dp = sampler(item_sampler, graph, fanouts)
-    if sampler_type == SamplerType.Temporal:
-        indices_len = [
-            {
-                "n1:e1:n2": 4,
-                "n2:e2:n1": 3,
-            },
-            {
-                "n1:e1:n2": 2,
-                "n2:e2:n1": 1,
-            },
-        ]
-    else:
-        indices_len = [
-            {
-                "n1:e1:n2": 4,
-                "n2:e2:n1": 2,
-            },
-            {
-                "n1:e1:n2": 2,
-                "n2:e2:n1": 1,
-            },
-        ]
-    for minibatch in sampler_dp:
-        for step, sampled_subgraph in enumerate(minibatch.sampled_subgraphs):
-            assert (
-                len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
-                == indices_len[step]["n1:e1:n2"]
-            )
-            assert (
-                len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
-                == indices_len[step]["n2:e2:n1"]
-            )
+    indices_len = [
+        {
+            "n1:e1:n2": 4,
+            "n2:e2:n1": item_length,
+        },
+        {
+            "n1:e1:n2": 2,
+            "n2:e2:n1": 1,
+        },
+    ]
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        for minibatch in sampler_dp:
+            for step, sampled_subgraph in enumerate(
+                minibatch.sampled_subgraphs
+            ):
+                assert (
+                    len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
+                    == indices_len[step]["n1:e1:n2"]
+                )
+                assert (
+                    len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
+                    == indices_len[step]["n2:e2:n1"]
+                )
 
 
 def test_SubgraphSampler_invoke():