-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest.py
153 lines (117 loc) · 5.62 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import argparse
from os import listdir
from os.path import join
import PIL.Image as pil_image
import PIL.ImageFilter as pil_image_filter
import cv2
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from model import Model
from utils import calc_psnr, calc_ssim, set_logging, select_device
from tqdm import tqdm
def main():
    """Evaluate the denoising model on a directory of noisy/clean image pairs.

    For every file name in ``--noisy-image-dir`` the same-named file is read
    from ``--clean-image-dir``. The model's prediction for the noisy image is
    written to ``--save-dir`` under the same name. With ``--stack-image``, a
    mosaic is written instead: noisy | prediction | clean on the top row and
    their FIND_EDGES maps on the bottom row. PSNR and SSIM of both the noisy
    input and the prediction (each measured against the clean image via
    ``calc_psnr`` / ``calc_ssim``) are saved to
    ``<save-dir>/image_quality_assessment.csv``, one row per image.

    Raises:
        FileNotFoundError: if a clean image with the same name is missing.
        OSError: if ``cv2.imwrite`` reports that a stacked image could not be
            written.
    """
    from os import makedirs

    # Argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="SAR-CAM")
    parser.add_argument("--weights-dir", type=str, required=True,
                        help="path of the state_dict checkpoint file to load")
    parser.add_argument("--clean-image-dir", type=str, required=True)
    parser.add_argument("--noisy-image-dir", type=str, required=True)
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--stack-image", action="store_true")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    args = parser.parse_args()

    # Get Current Namespace
    print(args)

    # Assign Device (assumes select_device returns a torch.device or device
    # string usable by both .to() and torch.load's map_location -- confirm)
    set_logging()
    device = select_device(args.model_name, args.device)

    # Create Model Instance. in_channels=1, so inputs are presumably
    # single-channel (grayscale) images -- TODO confirm against the dataset.
    model = Model(
        scale=2,
        in_channels=1,
        channels=128,
        kernel_size=3,
        stride=1,
        dilation=1,
        bias=True,
    ).to(device)
    # map_location lets a checkpoint saved on a GPU be restored onto whatever
    # device was selected (e.g. a CPU-only machine).
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.weights_dir, map_location=device))
    model.eval()

    # Create Torchvision Transforms Instance
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    # Ensure the output directory exists; otherwise cv2.imwrite fails silently
    # and PIL's save raises.
    makedirs(args.save_dir, exist_ok=True)

    # List the inputs exactly once (sorted, so the CSV row order is
    # deterministic) so the progress-bar total matches what is iterated.
    image_names = sorted(listdir(args.noisy_image_dir))

    # Per-image metrics, aligned index-for-index with image_names
    psnr_noisy_list, psnr_denoised_list = [], []
    ssim_noisy_list, ssim_denoised_list = [], []

    with torch.no_grad():
        for x in tqdm(image_names):
            # Load both images fully into memory; their files are closed here
            clean_image = _load_image(join(args.clean_image_dir, x))
            noisy_image = _load_image(join(args.noisy_image_dir, x))

            # Convert Pillow Image to PyTorch Tensor (adds a batch dimension)
            tensor_clean_image = to_tensor(clean_image).unsqueeze(0)
            tensor_noisy_image = to_tensor(noisy_image).unsqueeze(0).to(device)

            # Get Prediction; metrics are computed on CPU next to the clean tensor
            pred = model(tensor_noisy_image).cpu()
            tensor_noisy_image = tensor_noisy_image.cpu()

            # Calculate PSNR
            psnr_noisy_list.append(calc_psnr(tensor_noisy_image, tensor_clean_image).item())
            psnr_denoised_list.append(calc_psnr(pred, tensor_clean_image).item())
            # Calculate SSIM
            ssim_noisy_list.append(
                calc_ssim(tensor_noisy_image, tensor_clean_image, size_average=True).item())
            ssim_denoised_list.append(
                calc_ssim(pred, tensor_clean_image, size_average=True).item())

            # Clamp to the [0, 1] range before converting back to a Pillow Image
            pred_image = to_pil(torch.clamp(pred, min=0.0, max=1.0).squeeze(0))

            save_path = join(args.save_dir, x)
            if args.stack_image:
                stacked_image = _stack_with_edges(noisy_image, pred_image, clean_image)
                # cv2.imwrite signals failure via its return value, not an exception
                if not cv2.imwrite(save_path, stacked_image):
                    raise OSError(f"cv2.imwrite failed to write {save_path}")
            else:
                pred_image.save(save_path)

    # Save the per-image quality metrics as CSV, indexed by file name
    d = {
        "Noisy Image PSNR(dB)": psnr_noisy_list,
        "Noisy Image SSIM": ssim_noisy_list,
        "Denoised Image PSNR(dB)": psnr_denoised_list,
        "Denoised Image SSIM": ssim_denoised_list,
    }
    df = pd.DataFrame(data=d, index=image_names)
    df.to_csv(join(args.save_dir, "image_quality_assessment.csv"))


def _load_image(path):
    """Read the image at *path* into memory and close its file handle."""
    with pil_image.open(path) as image:
        # copy() forces the pixel data to load, so the returned image remains
        # usable after the underlying file has been closed.
        return image.copy()


def _stack_with_edges(noisy_image, pred_image, clean_image):
    """Return a uint8 mosaic: the three images on top, their edge maps below.

    Columns are ordered noisy | prediction | clean. The images must have
    matching heights (for ``np.hstack``) and the rows matching widths (for
    ``np.vstack``).
    """
    images = (noisy_image, pred_image, clean_image)
    edges = [image.filter(pil_image_filter.FIND_EDGES) for image in images]
    top_row = np.hstack([np.array(image, dtype="uint8") for image in images])
    bottom_row = np.hstack([np.array(edge, dtype="uint8") for edge in edges])
    return np.vstack((top_row, bottom_row))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()