-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_seg.py
107 lines (86 loc) · 3.17 KB
/
eval_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
import argparse
from model import create_model
import torch.backends.cudnn as cudnn
from data.DESC_dataset import DatasetProvider
from utils import update_opt
from copy import deepcopy
from model.encoderdecoder_model import EncoderDecoder
cudnn.benchmark = True
import ttach as tta
try:
from mmseg.ops import resize
except:
from mmseg.models.utils import resize
import torch
from trainer.seg_trainer import semseg_compute_confusion, semseg_accum_confusion_to_iou
def forward(self, img):
    """Inference-only forward pass that replaces ``EncoderDecoder.forward``.

    Unlike the stock mmseg training forward (which returns a dict of loss
    components), this returns raw per-pixel class logits upsampled back to
    the input resolution, which is what TTA wrapping and evaluation need.

    Args:
        img (Tensor): Input batch.
            # assumes NCHW layout, since img.shape[2:] is used as the
            # spatial target size — TODO confirm against DatasetProvider.
    Returns:
        Tensor: Segmentation logits with the same spatial size as ``img``.
    """
    # Backbone features -> main decode-head logits (typically at a
    # reduced stride relative to the input).
    x = self.extract_feat(img)
    x_main = self._decode_head_forward(x)
    # Upsample logits to the input spatial size for pixel-wise scoring.
    x_main = resize(
        input=x_main,
        size=img.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners
    )
    return x_main
# Monkey-patch: make EncoderDecoder return logits instead of losses so the
# model can be wrapped by ttach and called directly during evaluation.
EncoderDecoder.forward = forward
def semseg_accum_confusion_to_macc(confusion_accum):
    """Mean class accuracy (mAcc, percent) from an accumulated confusion matrix.

    Rows are ground-truth classes, columns are predictions; the per-class
    accuracy is diag / row-sum, averaged over all classes.
    """
    matrix = confusion_accum.double()
    per_class_correct = torch.diagonal(matrix)
    # Clamp the ground-truth totals to avoid division by zero for classes
    # that never appear in the labels.
    per_class_total = matrix.sum(dim=1).clamp(min=1e-12)
    per_class_acc = per_class_correct.mul(100) / per_class_total
    return per_class_acc.mean()
@torch.no_grad()
def main(args):
    """Evaluate a segmentation checkpoint with flip + multi-scale TTA.

    Builds the model described by the YAML option file, loads the given
    checkpoint, wraps the model in test-time augmentation, runs the
    validation split, and prints mean class accuracy (mAcc) and mean IoU.

    Args:
        args: Parsed CLI namespace; uses ``args.opt`` (YAML path) and
            ``args.checkpoint`` (weights path).
    """
    args, opt = update_opt(args.opt, args)
    backbone = create_model(opt["network"], opt['logger']['name'])
    model = EncoderDecoder(
        backbone,
        opt["head_main"],
        opt["head_aux"]
    )
    # Checkpoints store the weights under the 'model' key.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")['model']
    model.load_state_dict(checkpoint)
    model.eval()
    model.cuda()
    # TTA: horizontal flip x scales {1, 1.5, 2}; predictions from all
    # augmented views are averaged ('mean' merge).
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.Scale(scales=[1, 1.5, 2], interpolation="bilinear", align_corners=True),
        ]
    )
    model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    # NOTE(review): the original also built an unused eval_opt copy with
    # phase="eval"; it was dead code (the loader was built from dataset_opt)
    # and has been removed without changing behavior.
    dataset_opt = deepcopy(opt["datasets"])
    val_dataset = DatasetProvider(**dataset_opt, mode='val').get_dataset()
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=4, num_workers=2, drop_last=False,
        pin_memory=True, persistent_workers=True, shuffle=False)
    # Accumulate one confusion matrix over the whole split; starts as the
    # int 0 and becomes a Tensor on the first `+=`.
    confusion = 0
    for data in val_loader:
        event = data["event_voxel"].cuda()
        label = data["label"].cuda()
        pred_main = model(event)
        confusion += semseg_compute_confusion(pred_main.argmax(1), label).float()
    acc = semseg_accum_confusion_to_macc(confusion).item()
    iou = (semseg_accum_confusion_to_iou(confusion)[0]).item()
    print(args.opt, args.checkpoint)
    print(acc, iou)
if __name__ == "__main__":
    # CLI for standalone evaluation; gpus/acce/num_nodes/resume mirror the
    # training launcher's flags so the same option files can be reused.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--acce", type=str, default="ddp")
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--opt", type=str, default="", help='Path to option YAML file.')
    main(parser.parse_args())