-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
83 lines (65 loc) · 2.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from PIL import Image
import numpy as np
import torchvision
import torch
# Colour palette used by decode_parsing: entry i is the 3-component colour
# painted for label index i, written to output channels 0, 1 and 2 in order.
# NOTE(review): an earlier comment labelled indices 16-20 with Pascal VOC class
# names (16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor); that
# mapping was incomplete and is unconfirmed for this 22-entry palette.
COLORS = [
    [120, 120, 120],  # 0
    [127, 0, 0],      # 1
    [254, 0, 0],      # 2
    [0, 84, 0],       # 3
    [169, 0, 50],     # 4
    [254, 84, 0],     # 5
    [255, 0, 84],     # 6
    [0, 118, 220],    # 7
    [84, 84, 0],      # 8
    [0, 84, 84],      # 9
    [84, 50, 0],      # 10
    [51, 85, 127],    # 11
    [0, 127, 0],      # 12
    [0, 0, 254],      # 13
    [50, 169, 220],   # 14
    [0, 254, 254],    # 15
    [84, 254, 169],   # 16
    [169, 254, 84],   # 17
    [254, 254, 0],    # 18
    [254, 169, 0],    # 19
    [102, 254, 0],    # 20
    [182, 255, 0],    # 21
]
def decode_parsing(labels, num_images=1, num_classes=22, is_pred=False, colors=None):
    """Decode a batch of segmentation label maps into colour images.

    Args:
        labels: label maps of shape (N, H, W), or, when ``is_pred`` is True,
            per-class scores of shape (N, C, H, W) reduced with ``argmax``
            over dim 1.
        num_images: number of images to decode from the front of the batch;
            values larger than N simply decode the whole batch.
        num_classes: kept for backward compatibility and not used; the
            number of decoded labels is ``len(colors)``.
        is_pred: whether ``labels`` holds class scores rather than labels.
        colors: optional palette, a sequence of 3-component colours indexed
            by label. Defaults to the module-level ``COLORS``.

    Returns:
        A ``torch.uint8`` CPU tensor of shape (n, 3, H, W) with
        n = min(num_images, N). Pixels whose label has no palette entry
        stay (0, 0, 0).
    """
    palette = COLORS if colors is None else colors
    # The labels are only read below, never written, so no defensive clone
    # is needed; detach() replaces the legacy ``.data`` accessor.
    pred_labels = labels[:num_images].detach().cpu()
    if is_pred:
        pred_labels = torch.argmax(pred_labels, dim=1)
    n, h, w = pred_labels.size()
    labels_color = torch.zeros([n, 3, h, w], dtype=torch.uint8)
    # Per-channel views built once; masked writes through them fill labels_color.
    channels = [labels_color[:, ch] for ch in range(3)]
    for idx, color in enumerate(palette):
        # Compute the class mask once and reuse it for all three channels.
        mask = pred_labels == idx
        for channel, value in zip(channels, color):
            channel[mask] = value
    return labels_color
def inv_preprocess(imgs, num_images, mean=(0.485, 0.456, 0.406),
                   std=(0.229, 0.224, 0.225)):
    """Undo input normalisation on the first images of a batch.

    Each image is passed through ``NormalizeInverse(mean, std)``, which
    inverts ``torchvision.transforms.Normalize(mean, std)``. No channel
    reordering is performed.

    Args:
        imgs: batch of normalised images; each ``imgs[i]`` is fed to the
            transform (presumably a (C, H, W) float tensor - TODO confirm
            against callers).
        num_images: maximum number of images to process; if the batch is
            smaller, all of its images are processed.
        mean: per-channel mean of the forward normalisation. The default is
            the previously hard-coded value (presumably the usual ImageNet
            statistics).
        std: per-channel standard deviation of the forward normalisation.

    Returns:
        A new CPU tensor holding the first ``min(num_images, len(imgs))``
        images in the input domain. ``imgs`` itself is not modified.
    """
    # copy=True guarantees a fresh tensor (a single copy, even for GPU
    # inputs), so the in-place writes below never touch the caller's data.
    rev_imgs = imgs[:num_images].detach().to("cpu", copy=True)
    rev_normalize = NormalizeInverse(mean=mean, std=std)
    # Iterate over the images actually present: range(num_images) raised
    # IndexError whenever the batch held fewer than num_images images.
    for i in range(len(rev_imgs)):
        rev_imgs[i] = rev_normalize(rev_imgs[i])
    return rev_imgs
class NormalizeInverse(torchvision.transforms.Normalize):
    """Approximate inverse of ``torchvision.transforms.Normalize(mean, std)``.

    Normalize maps x to (x - mean) / std. Configuring the parent with
    mean' = -mean / std and std' = 1 / std maps y back to roughly
    y * std + mean, reconstructing images in the input domain. A 1e-7 term
    is added to std before taking the reciprocal, which guards against a
    zero std.
    """

    def __init__(self, mean, std):
        std_tensor = torch.as_tensor(std)
        mean_tensor = torch.as_tensor(mean)
        inv_std = 1 / (std_tensor + 1e-7)
        super().__init__(mean=-mean_tensor * inv_std, std=inv_std)
class AverageMeter:
    """Accumulate a weighted running sum, count and average of values.

    Attributes:
        name: optional label given by the caller (not used internally).
        sum: sum of ``val * n`` over all updates since the last reset.
        count: sum of ``n`` over all updates since the last reset.
        avg: ``sum / count``; left unchanged while ``count`` is zero.
    """

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        """Set sum, count and avg back to 0."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: value to accumulate; it is multiplied by ``n``.
            n: weight (number of samples) that ``val`` represents.
        """
        self.sum += val * n
        self.count += n
        # Guard the division: an n=0 update on a fresh or just-reset meter
        # previously raised ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count