-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinfer.py
52 lines (42 loc) · 1.71 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import pytorch_lightning as pl
import torch
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from model import LitRT4KSR_Rep
from utils import reparameterize, tensor2uint
import config
# Inference inputs/outputs, read once from the project-level ``config`` module.
model_path, lr_image_path, save_path = (
    config.checkpoint_path_infer,  # checkpoint handed to LitRT4KSR_Rep.load_from_checkpoint
    config.infer_lr_image_path,    # low-resolution image to upscale
    config.infer_save_path,        # directory the super-resolved image is written to
)
def main():
    """Upscale one low-resolution image with a trained LitRT4KSR_Rep model.

    Uses the module-level ``model_path``, ``lr_image_path`` and ``save_path``
    (read from ``config``).

    Steps:
    1. Pick the device from ``config.device``.
    2. Restore the Lightning module from the checkpoint.
    3. Optionally reparameterize the network (``config.infer_reparameterize``).
    4. Run ``predict_step`` on the image.
    5. Save the result to ``save_path`` under the input image's file name.

    Raises:
        ValueError: if the output path would overwrite the input image.
    """
    image_name = os.path.basename(lr_image_path)
    out_path = os.path.join(save_path, image_name)
    # The output reuses the input's file name, so a save_path equal to the
    # input's directory would silently replace the LR image. Fail before the
    # (expensive) model load instead.
    if os.path.realpath(out_path) == os.path.realpath(lr_image_path):
        raise ValueError(
            f"Output path {out_path!r} would overwrite the input image; "
            "choose a different infer_save_path."
        )

    # "auto" prefers CUDA when available; any other value goes to torch.device as-is.
    if config.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(config.device)
    print("Using device:", device)

    litmodel = LitRT4KSR_Rep.load_from_checkpoint(
        checkpoint_path=model_path,
        config=config,
        map_location=device,
    )
    if config.infer_reparameterize:
        # Replace the network with its reparameterized form; no checkpoint is saved.
        litmodel.model = reparameterize(config, litmodel.model, device, save_rep_checkpoint=False)
    # Harmless if already on `device`; required after the model is swapped above.
    litmodel.model.to(device)
    litmodel.eval()

    # Close the file handle promptly; convert("RGB") returns an independent copy.
    with Image.open(lr_image_path) as img:
        lr_image = img.convert("RGB")
    # to_tensor scales uint8 PIL data to float32 in [0, 1] with shape (C, H, W);
    # unsqueeze adds the batch dimension.
    lr_sample = TF.to_tensor(lr_image).unsqueeze(0).to(device)

    with torch.no_grad():
        sr_sample = litmodel.predict_step(lr_sample)

    # NOTE(review): the output is scaled by 255 *before* tensor2uint. Confirm that
    # utils.tensor2uint does not itself clamp to [0, 1] and rescale by 255,
    # otherwise the saved image saturates.
    sr_image = tensor2uint(sr_sample * 255.0)

    os.makedirs(save_path, exist_ok=True)
    Image.fromarray(sr_image).save(out_path)
    print(f"Inference done. Image saved to: {out_path}")
if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()