-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
70 lines (55 loc) · 2.13 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from paddle.metric import Accuracy
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay
from data import ModelNetDataset
from model import CrossEntropyMatrixRegularization, PointNetClassifier
def parse_args(argv=None):
    """Parse command-line options for evaluating a trained PointNet classifier.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            command-line behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    # prog was previously "Train", a copy-paste leftover in this test script.
    parser = argparse.ArgumentParser("Test")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size in evaluation"
    )
    # NOTE(review): --num_category is not passed to PointNetClassifier() in
    # test(); the model's class count is whatever its constructor defaults to.
    parser.add_argument("--num_category", type=int, default=40, help="ModelNet10/40")
    # Training-only options: test() never reads these. They are presumably kept
    # so command lines shared with the training script still parse.
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="learning rate in training"
    )
    parser.add_argument("--num_point", type=int, default=1024, help="point number")
    parser.add_argument("--max_epochs", type=int, default=200, help="max epochs")
    parser.add_argument(
        "--num_workers", type=int, default=32, help="number of DataLoader workers"
    )
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
    parser.add_argument("--log_freq", type=int, default=1)
    parser.add_argument("--verbose", type=int, default=1)
    parser.add_argument(
        "--model_path",
        type=str,
        default="pointnet.pdparams",
        help="path to the saved model state dict",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="modelnet40_normal_resampled",
        help="root directory of the ModelNet dataset",
    )
    return parser.parse_args(argv)
def test(args):
    """Evaluate a saved PointNet classifier on the test split and print accuracy.

    Args:
        args: Namespace from ``parse_args``. Reads ``data_dir``, ``num_point``,
            ``num_workers``, ``batch_size`` and ``model_path``.

    Returns:
        The value of ``Accuracy.accumulate()`` over the whole test split. It is
        also printed, as before.
    """
    test_data = ModelNetDataset(args.data_dir, split="test", num_point=args.num_point)
    test_loader = DataLoader(
        test_data,
        shuffle=False,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
    )

    model = PointNetClassifier()
    # NOTE(review): paddle.load deserializes pickled data; only load
    # checkpoints from trusted sources.
    model.set_state_dict(paddle.load(args.model_path))
    model.eval()

    metrics = Accuracy()
    metrics.reset()
    # Evaluation only: disable autograd so no backward graph is recorded.
    with paddle.no_grad():
        for x, y in test_loader:
            # The model returns a 3-tuple; only the first element is scored
            # (presumably the class logits — confirm against model.py).
            pred, _, _ = model(x)
            metrics.update(metrics.compute(pred, y))

    test_acc = metrics.accumulate()
    print("Test Accuracy: {}".format(test_acc))
    return test_acc
if __name__ == "__main__":
    # Script entry point: parse the CLI, echo the resulting configuration,
    # then run the evaluation with it.
    cli_args = parse_args()
    print(cli_args)
    test(cli_args)