eval.py
import os
import glob

import torch
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from tqdm import tqdm

from model import LitRT4KSR_Rep
from utils import reparameterize, calculate_psnr, calculate_ssim, tensor2uint
import config

model_path = config.checkpoint_path_eval
lr_image_dir = config.eval_lr_image_dir
hr_image_dir = config.eval_hr_image_dir
save_path = config.val_save_path
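
# Note: the attributes read from `config` above and in main() are assumed to be
# defined in config.py. The names are taken directly from their usage in this
# script; the example values below are only illustrative placeholders, not the
# project's defaults:
#
#   device = "auto"                # "auto", "cpu", or an explicit device such as "cuda:0"
#   checkpoint_path_eval = "..."   # Lightning checkpoint (.ckpt) to evaluate
#   eval_lr_image_dir = "..."      # directory of low-resolution input PNGs
#   eval_hr_image_dir = "..."      # directory of matching high-resolution ground-truth PNGs
#   val_save_path = "..."          # directory where super-resolved outputs are written
#   eval_reparameterize = True     # fold the rep branches into inference form before eval
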
def main():
    # Honour an explicit device from config; otherwise auto-detect CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") \
        if config.device == "auto" else torch.device(config.device)
    print("Using device:", device)

    # Load the Lightning checkpoint onto the chosen device.
    litmodel = LitRT4KSR_Rep.load_from_checkpoint(
        checkpoint_path=model_path,
        config=config,
        map_location=device
    )
    if config.eval_reparameterize:  # reparameterize the model into its inference form
        litmodel.model = reparameterize(config, litmodel.model, device, save_rep_checkpoint=False)
    litmodel.model.to(device)
    litmodel.eval()
    # Collect the LR/HR pairs; sort both lists so that matching filenames line up.
    list_lr_image_path = sorted(glob.glob(os.path.join(lr_image_dir, "*.png")))
    list_hr_image_path = sorted(glob.glob(os.path.join(hr_image_dir, "*.png")))
    psnr_RGB_lst, ssim_RGB_lst, psnr_Y_lst, ssim_Y_lst = [], [], [], []

    # Evaluation
    list_pair_lr_hr_image_path = list(zip(list_lr_image_path, list_hr_image_path))
    for lr_image_path, hr_image_path in tqdm(list_pair_lr_hr_image_path, desc="Eval"):
        image_name = os.path.basename(lr_image_path)

        # Convert the images to [0, 1] float tensors and move them to the device
        lr_image = Image.open(lr_image_path).convert("RGB")
        hr_image = Image.open(hr_image_path).convert("RGB")
        lr_sample = TF.to_tensor(np.array(lr_image) / 255.0).unsqueeze(0).float().to(device)
        hr_sample = TF.to_tensor(np.array(hr_image) / 255.0).unsqueeze(0).float().to(device)

        # Super-resolve, then convert SR and HR tensors back to uint8 arrays
        with torch.no_grad():
            sr_sample = litmodel.predict_step(lr_sample)
        sr_sample = tensor2uint(sr_sample * 255.0)
        hr_sample = tensor2uint(hr_sample * 255.0)

        # Calculate PSNR and SSIM, on RGB and on the Y channel
        psnr_RGB_lst.append(calculate_psnr(sr_sample, hr_sample, crop_border=0, test_y_channel=False))
        ssim_RGB_lst.append(calculate_ssim(sr_sample, hr_sample, crop_border=0, test_y_channel=False))
        psnr_Y_lst.append(calculate_psnr(sr_sample, hr_sample, crop_border=0, test_y_channel=True))
        ssim_Y_lst.append(calculate_ssim(sr_sample, hr_sample, crop_border=0, test_y_channel=True))

        # Save the super-resolved image under its original filename
        image_sr_PIL = Image.fromarray(sr_sample)
        os.makedirs(save_path, exist_ok=True)
        image_sr_PIL.save(os.path.join(save_path, image_name))
    # Show results
    print("Average PSNR (RGB):", sum(psnr_RGB_lst) / len(psnr_RGB_lst))
    print("Average PSNR (Y)  :", sum(psnr_Y_lst) / len(psnr_Y_lst))
    print("Average SSIM (RGB):", sum(ssim_RGB_lst) / len(ssim_RGB_lst))
    print("Average SSIM (Y)  :", sum(ssim_Y_lst) / len(ssim_Y_lst))


if __name__ == "__main__":
    main()
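
# Usage sketch (assuming config.py provides the fields listed near the top of the file):
#
#   python eval.py
#
# The script writes each super-resolved PNG to config.val_save_path under its
# original filename and prints the average PSNR/SSIM over the evaluation set,
# both in RGB and on the Y channel.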