2
2
3
3
# 4 points on a diagonal line with d^2 = 0.1^2+0.1^2 = 0.02 between them.
4
4
# 1 point very far away.
5
- nodes = torch .FloatTensor ([
6
- # Event 0
7
- [.1 , .1 ],
8
- [.2 , .2 ],
9
- [.3 , .3 ],
10
- [.4 , .4 ],
11
- [100. , 100. ],
12
- # Event 1
13
- [.1 , .1 ],
14
- [.2 , .2 ],
15
- [.3 , .3 ],
16
- [.4 , .4 ]
17
- ])
18
- batch = torch .LongTensor ([0 ,0 ,0 ,0 ,0 ,1 ,1 ,1 ,1 ])
5
+ nodes = torch .FloatTensor (
6
+ [
7
+ # Event 0
8
+ [0.1 , 0.1 ],
9
+ [0.2 , 0.2 ],
10
+ [0.3 , 0.3 ],
11
+ [0.4 , 0.4 ],
12
+ [100.0 , 100.0 ],
13
+ # Event 1
14
+ [0.1 , 0.1 ],
15
+ [0.2 , 0.2 ],
16
+ [0.3 , 0.3 ],
17
+ [0.4 , 0.4 ],
18
+ ]
19
+ )
20
+ batch = torch .LongTensor ([0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ])
19
21
20
22
# Expected output for k=3, max_radius=0.2 (with loop)
21
23
# Always a connection with self, which has distance 0.0
22
- expected_neigh_indices = torch .IntTensor ([
23
- [ 0 , 1 , - 1 ],
24
- [ 1 , 0 , 2 ],
25
- [ 2 , 1 , 3 ],
26
- [ 3 , 2 , - 1 ],
27
- [ 4 , - 1 , - 1 ],
28
- [ 5 , 6 , - 1 ],
29
- [ 6 , 5 , 7 ],
30
- [ 7 , 6 , 8 ],
31
- [ 8 , 7 , - 1 ]
32
- ])
33
- expected_neigh_dist_sq = torch .FloatTensor ([
34
- [0.0 , 0.02 , 0.00 ],
35
- [0.0 , 0.02 , 0.02 ],
36
- [0.0 , 0.02 , 0.02 ],
37
- [0.0 , 0.02 , 0.00 ],
38
- [0.0 , 0.00 , 0.00 ],
39
- [0.0 , 0.02 , 0.00 ],
40
- [0.0 , 0.02 , 0.02 ],
41
- [0.0 , 0.02 , 0.02 ],
42
- [0.0 , 0.02 , 0.00 ]
43
- ])
44
- expected_edge_index_noloop = torch .LongTensor ([
45
- [0 , 1 , 1 , 2 , 2 , 3 , 5 , 6 , 6 , 7 , 7 , 8 ],
46
- [1 , 0 , 2 , 1 , 3 , 2 , 6 , 5 , 7 , 6 , 8 , 7 ]
47
- ])
48
- expected_edge_index_loop = torch .LongTensor ([
49
- [0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 4 , 5 , 5 , 6 , 6 , 6 , 7 , 7 , 7 , 8 , 8 ],
50
- [0 , 1 , 1 , 0 , 2 , 2 , 1 , 3 , 3 , 2 , 4 , 5 , 6 , 6 , 5 , 7 , 7 , 6 , 8 , 8 , 7 ]
51
- ])
24
+ expected_neigh_indices = torch .IntTensor (
25
+ [
26
+ [0 , 1 , - 1 ],
27
+ [1 , 0 , 2 ],
28
+ [2 , 1 , 3 ],
29
+ [3 , 2 , - 1 ],
30
+ [4 , - 1 , - 1 ],
31
+ [5 , 6 , - 1 ],
32
+ [6 , 5 , 7 ],
33
+ [7 , 6 , 8 ],
34
+ [8 , 7 , - 1 ],
35
+ ]
36
+ )
37
+ expected_neigh_dist_sq = torch .FloatTensor (
38
+ [
39
+ [0.0 , 0.02 , 0.00 ],
40
+ [0.0 , 0.02 , 0.02 ],
41
+ [0.0 , 0.02 , 0.02 ],
42
+ [0.0 , 0.02 , 0.00 ],
43
+ [0.0 , 0.00 , 0.00 ],
44
+ [0.0 , 0.02 , 0.00 ],
45
+ [0.0 , 0.02 , 0.02 ],
46
+ [0.0 , 0.02 , 0.02 ],
47
+ [0.0 , 0.02 , 0.00 ],
48
+ ]
49
+ )
50
+ expected_edge_index_noloop = torch .LongTensor (
51
+ [[0 , 1 , 1 , 2 , 2 , 3 , 5 , 6 , 6 , 7 , 7 , 8 ], [1 , 0 , 2 , 1 , 3 , 2 , 6 , 5 , 7 , 6 , 8 , 7 ]]
52
+ )
53
+ expected_edge_index_loop = torch .LongTensor (
54
+ [
55
+ [0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 4 , 5 , 5 , 6 , 6 , 6 , 7 , 7 , 7 , 8 , 8 ],
56
+ [0 , 1 , 1 , 0 , 2 , 2 , 1 , 3 , 3 , 2 , 4 , 5 , 6 , 6 , 5 , 7 , 7 , 6 , 8 , 8 , 7 ],
57
+ ]
58
+ )
52
59
53
60
54
61
def test_knn_graph_cpu ():
55
62
from torch_cmspepr import knn_graph
63
+
56
64
# k=2 without loops
57
- edge_index = knn_graph (nodes , 2 , batch , max_radius = .2 )
65
+ edge_index = knn_graph (nodes , 2 , batch , max_radius = 0 .2 )
58
66
print ('Found edge_index:' )
59
67
print (edge_index )
60
68
print ('Expected edge_index:' )
61
69
print (expected_edge_index_noloop )
62
70
assert torch .allclose (edge_index , expected_edge_index_noloop )
63
71
# k=3 with loops
64
- edge_index = knn_graph (nodes , 3 , batch , max_radius = .2 , loop = True )
72
+ edge_index = knn_graph (nodes , 3 , batch , max_radius = 0 .2 , loop = True )
65
73
print ('Found edge_index:' )
66
74
print (edge_index )
67
75
print ('Expected edge_index:' )
68
76
print (expected_edge_index_loop )
69
77
assert torch .allclose (edge_index , expected_edge_index_loop )
70
78
# k=3 with loops
71
- edge_index = knn_graph (nodes , 3 , batch , max_radius = .2 , loop = True , flow = 'target_to_source' )
79
+ edge_index = knn_graph (
80
+ nodes , 3 , batch , max_radius = 0.2 , loop = True , flow = 'target_to_source'
81
+ )
72
82
print ('Found edge_index:' )
73
83
print (edge_index )
74
84
print ('Expected edge_index:' )
75
85
expected = torch .flip (expected_edge_index_loop , [0 ])
76
86
print (expected )
77
87
assert torch .allclose (edge_index , expected )
78
88
89
+
79
90
def test_knn_graph_cpu_1dim ():
80
91
from torch_cmspepr import knn_graph
92
+
81
93
nodes = torch .FloatTensor ([0.1 , 0.2 , 0.5 , 0.6 ])
82
- edge_index = knn_graph (nodes , 2 , max_radius = .2 , loop = True )
83
- expected = torch .LongTensor ([
84
- [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ],
85
- [0 , 1 , 1 , 0 , 2 , 3 , 3 , 2 ],
86
- ])
94
+ edge_index = knn_graph (nodes , 2 , max_radius = 0.2 , loop = True )
95
+ expected = torch .LongTensor (
96
+ [
97
+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ],
98
+ [0 , 1 , 1 , 0 , 2 , 3 , 3 , 2 ],
99
+ ]
100
+ )
87
101
print ('Found edge_index:' )
88
102
print (edge_index )
89
103
print ('Expected edge_index:' )
90
104
print (expected )
91
105
assert torch .allclose (edge_index , expected )
92
106
107
+
93
108
def test_knn_graph_cuda ():
94
109
from torch_cmspepr import knn_graph
110
+
95
111
gpu = torch .device ('cuda' )
96
112
nodes_cuda , batch_cuda = nodes .to (gpu ), batch .to (gpu )
97
113
# k=2 without loops
98
- edge_index = knn_graph (nodes_cuda , 2 , batch_cuda , max_radius = .2 )
114
+ edge_index = knn_graph (nodes_cuda , 2 , batch_cuda , max_radius = 0 .2 )
99
115
print ('[k=2 no loops] Found edge_index:' )
100
116
print (edge_index )
101
117
print ('Expected edge_index:' )
102
118
print (expected_edge_index_noloop )
103
119
assert torch .allclose (edge_index , expected_edge_index_noloop .to (gpu ))
104
120
# k=3 with loops
105
- edge_index = knn_graph (nodes_cuda , 3 , batch_cuda , max_radius = .2 , loop = True )
121
+ edge_index = knn_graph (nodes_cuda , 3 , batch_cuda , max_radius = 0 .2 , loop = True )
106
122
print ('[k=3 with loops] Found edge_index:' )
107
123
print (edge_index )
108
124
print ('Expected edge_index:' )
109
125
print (expected_edge_index_loop )
110
126
assert torch .allclose (edge_index , expected_edge_index_loop .to (gpu ))
111
127
128
+
112
129
def test_select_knn_cpu ():
113
130
from torch_cmspepr import select_knn
114
- neigh_indices , neigh_dist_sq = select_knn (nodes , k = 3 , batch_x = batch , max_radius = .2 )
131
+
132
+ neigh_indices , neigh_dist_sq = select_knn (nodes , k = 3 , batch_x = batch , max_radius = 0.2 )
115
133
print ('Expected indices:' )
116
134
print (expected_neigh_indices )
117
135
print ('Found indices:' )
@@ -123,13 +141,14 @@ def test_select_knn_cpu():
123
141
assert torch .allclose (neigh_indices , expected_neigh_indices )
124
142
assert torch .allclose (neigh_dist_sq , expected_neigh_dist_sq )
125
143
144
+
126
145
def test_select_knn_cuda ():
127
146
from torch_cmspepr import select_knn
147
+
128
148
device = torch .device ('cuda' )
129
149
neigh_indices , neigh_dist_sq = select_knn (
130
- nodes .to (device ), k = 3 ,
131
- batch_x = batch .to (device ), max_radius = .2
132
- )
150
+ nodes .to (device ), k = 3 , batch_x = batch .to (device ), max_radius = 0.2
151
+ )
133
152
neigh_indices = neigh_indices .cpu ()
134
153
neigh_dist_sq = neigh_dist_sq .cpu ()
135
154
print ('Expected indices:' )
@@ -141,4 +160,4 @@ def test_select_knn_cuda():
141
160
print ('Found dist_sq:' )
142
161
print (neigh_dist_sq )
143
162
assert torch .allclose (neigh_indices , expected_neigh_indices )
144
- assert torch .allclose (neigh_dist_sq , expected_neigh_dist_sq )
163
+ assert torch .allclose (neigh_dist_sq , expected_neigh_dist_sq )
0 commit comments