Skip to content

Commit e0d82f1

Browse files
author
Thomas Klijnsma
committed
formatting and linting
1 parent 713b03f commit e0d82f1

File tree

4 files changed

+160
-120
lines changed

4 files changed

+160
-120
lines changed

tests/test_extensions.py

+48-40
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,59 @@
33

44
# 4 points on a diagonal line with d^2 = 0.1^2+0.1^2 = 0.02 between them.
55
# 1 point very far away.
6-
nodes = torch.FloatTensor([
7-
# Event 0
8-
[.1, .1],
9-
[.2, .2],
10-
[.3, .3],
11-
[.4, .4],
12-
[100., 100.],
13-
# Event 1
14-
[.1, .1],
15-
[.2, .2],
16-
[.3, .3],
17-
[.4, .4]
18-
])
6+
nodes = torch.FloatTensor(
7+
[
8+
# Event 0
9+
[0.1, 0.1],
10+
[0.2, 0.2],
11+
[0.3, 0.3],
12+
[0.4, 0.4],
13+
[100.0, 100.0],
14+
# Event 1
15+
[0.1, 0.1],
16+
[0.2, 0.2],
17+
[0.3, 0.3],
18+
[0.4, 0.4],
19+
]
20+
)
1921
row_splits = torch.IntTensor([0, 5, 9])
2022
mask: torch.Tensor = torch.ones(nodes.shape[0], dtype=torch.int32)
21-
max_radius = .2
23+
max_radius = 0.2
2224
mask_mode = 1
23-
k = 3 # Including connection with self
25+
k = 3 # Including connection with self
2426

2527
# Expected output for k=3, max_radius=0.2 (with loop)
2628
# Always a connection with self, which has distance 0.0
27-
expected_neigh_indices = torch.IntTensor([
28-
[ 0, 1, -1],
29-
[ 1, 0, 2],
30-
[ 2, 1, 3],
31-
[ 3, 2, -1],
32-
[ 4, -1, -1],
33-
[ 5, 6, -1],
34-
[ 6, 5, 7],
35-
[ 7, 6, 8],
36-
[ 8, 7, -1]
37-
])
38-
expected_neigh_dist_sq = torch.FloatTensor([
39-
[0.0, 0.02, 0.00],
40-
[0.0, 0.02, 0.02],
41-
[0.0, 0.02, 0.02],
42-
[0.0, 0.02, 0.00],
43-
[0.0, 0.00, 0.00],
44-
[0.0, 0.02, 0.00],
45-
[0.0, 0.02, 0.02],
46-
[0.0, 0.02, 0.02],
47-
[0.0, 0.02, 0.00]
48-
])
29+
expected_neigh_indices = torch.IntTensor(
30+
[
31+
[0, 1, -1],
32+
[1, 0, 2],
33+
[2, 1, 3],
34+
[3, 2, -1],
35+
[4, -1, -1],
36+
[5, 6, -1],
37+
[6, 5, 7],
38+
[7, 6, 8],
39+
[8, 7, -1],
40+
]
41+
)
42+
expected_neigh_dist_sq = torch.FloatTensor(
43+
[
44+
[0.0, 0.02, 0.00],
45+
[0.0, 0.02, 0.02],
46+
[0.0, 0.02, 0.02],
47+
[0.0, 0.02, 0.00],
48+
[0.0, 0.00, 0.00],
49+
[0.0, 0.02, 0.00],
50+
[0.0, 0.02, 0.02],
51+
[0.0, 0.02, 0.02],
52+
[0.0, 0.02, 0.00],
53+
]
54+
)
4955

5056
SO_DIR = osp.dirname(osp.dirname(osp.abspath(__file__)))
5157

58+
5259
def test_select_knn_op_cpu():
5360
torch.ops.load_library(osp.join(SO_DIR, 'select_knn_cpu.so'))
5461
neigh_indices, neigh_dist_sq = torch.ops.select_knn_cpu.select_knn_cpu(
@@ -58,7 +65,7 @@ def test_select_knn_op_cpu():
5865
k,
5966
max_radius,
6067
mask_mode,
61-
)
68+
)
6269
print('Expected indices:')
6370
print(expected_neigh_indices)
6471
print('Found indices:')
@@ -70,6 +77,7 @@ def test_select_knn_op_cpu():
7077
assert torch.allclose(neigh_indices, expected_neigh_indices)
7178
assert torch.allclose(neigh_dist_sq, expected_neigh_dist_sq)
7279

80+
7381
def test_select_knn_op_cuda():
7482
gpu = torch.device('cuda')
7583
torch.ops.load_library(osp.join(SO_DIR, 'select_knn_cuda.so'))
@@ -80,7 +88,7 @@ def test_select_knn_op_cuda():
8088
k,
8189
max_radius,
8290
mask_mode,
83-
)
91+
)
8492
print('Expected indices:')
8593
print(expected_neigh_indices)
8694
print('Found indices:')
@@ -90,4 +98,4 @@ def test_select_knn_op_cuda():
9098
print('Found dist_sq:')
9199
print(neigh_dist_sq)
92100
assert torch.allclose(neigh_indices.cpu(), expected_neigh_indices)
93-
assert torch.allclose(neigh_dist_sq.cpu(), expected_neigh_dist_sq)
101+
assert torch.allclose(neigh_dist_sq.cpu(), expected_neigh_dist_sq)

tests/test_knn.py

+78-59
Original file line numberDiff line numberDiff line change
@@ -2,116 +2,134 @@
22

33
# 4 points on a diagonal line with d^2 = 0.1^2+0.1^2 = 0.02 between them.
44
# 1 point very far away.
5-
nodes = torch.FloatTensor([
6-
# Event 0
7-
[.1, .1],
8-
[.2, .2],
9-
[.3, .3],
10-
[.4, .4],
11-
[100., 100.],
12-
# Event 1
13-
[.1, .1],
14-
[.2, .2],
15-
[.3, .3],
16-
[.4, .4]
17-
])
18-
batch = torch.LongTensor([0,0,0,0,0,1,1,1,1])
5+
nodes = torch.FloatTensor(
6+
[
7+
# Event 0
8+
[0.1, 0.1],
9+
[0.2, 0.2],
10+
[0.3, 0.3],
11+
[0.4, 0.4],
12+
[100.0, 100.0],
13+
# Event 1
14+
[0.1, 0.1],
15+
[0.2, 0.2],
16+
[0.3, 0.3],
17+
[0.4, 0.4],
18+
]
19+
)
20+
batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1])
1921

2022
# Expected output for k=3, max_radius=0.2 (with loop)
2123
# Always a connection with self, which has distance 0.0
22-
expected_neigh_indices = torch.IntTensor([
23-
[ 0, 1, -1],
24-
[ 1, 0, 2],
25-
[ 2, 1, 3],
26-
[ 3, 2, -1],
27-
[ 4, -1, -1],
28-
[ 5, 6, -1],
29-
[ 6, 5, 7],
30-
[ 7, 6, 8],
31-
[ 8, 7, -1]
32-
])
33-
expected_neigh_dist_sq = torch.FloatTensor([
34-
[0.0, 0.02, 0.00],
35-
[0.0, 0.02, 0.02],
36-
[0.0, 0.02, 0.02],
37-
[0.0, 0.02, 0.00],
38-
[0.0, 0.00, 0.00],
39-
[0.0, 0.02, 0.00],
40-
[0.0, 0.02, 0.02],
41-
[0.0, 0.02, 0.02],
42-
[0.0, 0.02, 0.00]
43-
])
44-
expected_edge_index_noloop = torch.LongTensor([
45-
[0, 1, 1, 2, 2, 3, 5, 6, 6, 7, 7, 8],
46-
[1, 0, 2, 1, 3, 2, 6, 5, 7, 6, 8, 7]
47-
])
48-
expected_edge_index_loop = torch.LongTensor([
49-
[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8],
50-
[0, 1, 1, 0, 2, 2, 1, 3, 3, 2, 4, 5, 6, 6, 5, 7, 7, 6, 8, 8, 7]
51-
])
24+
expected_neigh_indices = torch.IntTensor(
25+
[
26+
[0, 1, -1],
27+
[1, 0, 2],
28+
[2, 1, 3],
29+
[3, 2, -1],
30+
[4, -1, -1],
31+
[5, 6, -1],
32+
[6, 5, 7],
33+
[7, 6, 8],
34+
[8, 7, -1],
35+
]
36+
)
37+
expected_neigh_dist_sq = torch.FloatTensor(
38+
[
39+
[0.0, 0.02, 0.00],
40+
[0.0, 0.02, 0.02],
41+
[0.0, 0.02, 0.02],
42+
[0.0, 0.02, 0.00],
43+
[0.0, 0.00, 0.00],
44+
[0.0, 0.02, 0.00],
45+
[0.0, 0.02, 0.02],
46+
[0.0, 0.02, 0.02],
47+
[0.0, 0.02, 0.00],
48+
]
49+
)
50+
expected_edge_index_noloop = torch.LongTensor(
51+
[[0, 1, 1, 2, 2, 3, 5, 6, 6, 7, 7, 8], [1, 0, 2, 1, 3, 2, 6, 5, 7, 6, 8, 7]]
52+
)
53+
expected_edge_index_loop = torch.LongTensor(
54+
[
55+
[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8],
56+
[0, 1, 1, 0, 2, 2, 1, 3, 3, 2, 4, 5, 6, 6, 5, 7, 7, 6, 8, 8, 7],
57+
]
58+
)
5259

5360

5461
def test_knn_graph_cpu():
5562
from torch_cmspepr import knn_graph
63+
5664
# k=2 without loops
57-
edge_index = knn_graph(nodes, 2, batch, max_radius=.2)
65+
edge_index = knn_graph(nodes, 2, batch, max_radius=0.2)
5866
print('Found edge_index:')
5967
print(edge_index)
6068
print('Expected edge_index:')
6169
print(expected_edge_index_noloop)
6270
assert torch.allclose(edge_index, expected_edge_index_noloop)
6371
# k=3 with loops
64-
edge_index = knn_graph(nodes, 3, batch, max_radius=.2, loop=True)
72+
edge_index = knn_graph(nodes, 3, batch, max_radius=0.2, loop=True)
6573
print('Found edge_index:')
6674
print(edge_index)
6775
print('Expected edge_index:')
6876
print(expected_edge_index_loop)
6977
assert torch.allclose(edge_index, expected_edge_index_loop)
7078
# k=3 with loops
71-
edge_index = knn_graph(nodes, 3, batch, max_radius=.2, loop=True, flow='target_to_source')
79+
edge_index = knn_graph(
80+
nodes, 3, batch, max_radius=0.2, loop=True, flow='target_to_source'
81+
)
7282
print('Found edge_index:')
7383
print(edge_index)
7484
print('Expected edge_index:')
7585
expected = torch.flip(expected_edge_index_loop, [0])
7686
print(expected)
7787
assert torch.allclose(edge_index, expected)
7888

89+
7990
def test_knn_graph_cpu_1dim():
8091
from torch_cmspepr import knn_graph
92+
8193
nodes = torch.FloatTensor([0.1, 0.2, 0.5, 0.6])
82-
edge_index = knn_graph(nodes, 2, max_radius=.2, loop=True)
83-
expected = torch.LongTensor([
84-
[0, 0, 1, 1, 2, 2, 3, 3],
85-
[0, 1, 1, 0, 2, 3, 3, 2],
86-
])
94+
edge_index = knn_graph(nodes, 2, max_radius=0.2, loop=True)
95+
expected = torch.LongTensor(
96+
[
97+
[0, 0, 1, 1, 2, 2, 3, 3],
98+
[0, 1, 1, 0, 2, 3, 3, 2],
99+
]
100+
)
87101
print('Found edge_index:')
88102
print(edge_index)
89103
print('Expected edge_index:')
90104
print(expected)
91105
assert torch.allclose(edge_index, expected)
92106

107+
93108
def test_knn_graph_cuda():
94109
from torch_cmspepr import knn_graph
110+
95111
gpu = torch.device('cuda')
96112
nodes_cuda, batch_cuda = nodes.to(gpu), batch.to(gpu)
97113
# k=2 without loops
98-
edge_index = knn_graph(nodes_cuda, 2, batch_cuda, max_radius=.2)
114+
edge_index = knn_graph(nodes_cuda, 2, batch_cuda, max_radius=0.2)
99115
print('[k=2 no loops] Found edge_index:')
100116
print(edge_index)
101117
print('Expected edge_index:')
102118
print(expected_edge_index_noloop)
103119
assert torch.allclose(edge_index, expected_edge_index_noloop.to(gpu))
104120
# k=3 with loops
105-
edge_index = knn_graph(nodes_cuda, 3, batch_cuda, max_radius=.2, loop=True)
121+
edge_index = knn_graph(nodes_cuda, 3, batch_cuda, max_radius=0.2, loop=True)
106122
print('[k=3 with loops] Found edge_index:')
107123
print(edge_index)
108124
print('Expected edge_index:')
109125
print(expected_edge_index_loop)
110126
assert torch.allclose(edge_index, expected_edge_index_loop.to(gpu))
111127

128+
112129
def test_select_knn_cpu():
113130
from torch_cmspepr import select_knn
114-
neigh_indices, neigh_dist_sq = select_knn(nodes, k=3, batch_x=batch, max_radius=.2)
131+
132+
neigh_indices, neigh_dist_sq = select_knn(nodes, k=3, batch_x=batch, max_radius=0.2)
115133
print('Expected indices:')
116134
print(expected_neigh_indices)
117135
print('Found indices:')
@@ -123,13 +141,14 @@ def test_select_knn_cpu():
123141
assert torch.allclose(neigh_indices, expected_neigh_indices)
124142
assert torch.allclose(neigh_dist_sq, expected_neigh_dist_sq)
125143

144+
126145
def test_select_knn_cuda():
127146
from torch_cmspepr import select_knn
147+
128148
device = torch.device('cuda')
129149
neigh_indices, neigh_dist_sq = select_knn(
130-
nodes.to(device), k=3,
131-
batch_x=batch.to(device), max_radius=.2
132-
)
150+
nodes.to(device), k=3, batch_x=batch.to(device), max_radius=0.2
151+
)
133152
neigh_indices = neigh_indices.cpu()
134153
neigh_dist_sq = neigh_dist_sq.cpu()
135154
print('Expected indices:')
@@ -141,4 +160,4 @@ def test_select_knn_cuda():
141160
print('Found dist_sq:')
142161
print(neigh_dist_sq)
143162
assert torch.allclose(neigh_indices, expected_neigh_indices)
144-
assert torch.allclose(neigh_dist_sq, expected_neigh_dist_sq)
163+
assert torch.allclose(neigh_dist_sq, expected_neigh_dist_sq)

torch_cmspepr/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ruff: noqa: E402
12
import os.path as osp
23
import logging
34
import torch
@@ -33,6 +34,7 @@ def setup_logger(name: str = "cmspepr") -> logging.Logger:
3334
logger.addHandler(handler)
3435
return logger
3536

37+
3638
logger = setup_logger()
3739

3840

@@ -43,6 +45,7 @@ def load_ops(so_file):
4345
else:
4446
torch.ops.load_library(so_file)
4547

48+
4649
THISDIR = osp.dirname(osp.abspath(__file__))
4750
load_ops(osp.join(THISDIR, "../select_knn_cpu.so"))
4851
load_ops(osp.join(THISDIR, "../select_knn_cuda.so"))

0 commit comments

Comments
 (0)