-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathprepare.py
73 lines (51 loc) · 2.21 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from torchvision.transforms import transforms
def train(args):
    """Build the training HDF5 file of paired low-/high-resolution patches.

    Every path matched by ``<args.images_dir>/*`` is processed in sorted
    order. Each image is converted to RGB and passed through
    ``transforms.FiveCrop`` with a crop size of half the image height and
    half the image width. Every crop is resized so that both sides are
    multiples of ``args.scale`` (the HR patch) and then shrunk by
    ``args.scale`` (the LR patch), both with bicubic resampling. The two
    patches are stored as datasets named by a running patch index under the
    ``hr`` and ``lr`` groups of the file at ``args.output_path``.

    Args:
        args: Namespace providing ``images_dir``, ``output_path`` and
            ``scale``.
    """
    scale = args.scale

    # The context manager closes the HDF5 file even when an image fails to
    # load or resize, instead of only on the success path.
    with h5py.File(args.output_path, 'w') as h5_file:
        lr_group = h5_file.create_group('lr')
        hr_group = h5_file.create_group('hr')

        image_list = sorted(glob.glob('{}/*'.format(args.images_dir)))
        patch_idx = 0

        for i, image_path in enumerate(image_list):
            # convert() returns a converted copy, so the source file handle
            # can be released as soon as the conversion is done.
            with pil_image.open(image_path) as source:
                image = source.convert('RGB')

            five_crop = transforms.FiveCrop(size=(image.height // 2, image.width // 2))
            for crop in five_crop(image):
                # Snap the crop to the nearest lower multiple of `scale` so
                # the LR patch has an exact integer size.
                hr = crop.resize(
                    ((crop.width // scale) * scale, (crop.height // scale) * scale),
                    resample=pil_image.BICUBIC,
                )
                lr = hr.resize(
                    (hr.width // scale, hr.height // scale),
                    resample=pil_image.BICUBIC,
                )

                lr_group.create_dataset(str(patch_idx), data=np.array(lr))
                hr_group.create_dataset(str(patch_idx), data=np.array(hr))
                patch_idx += 1

            # Progress: image index, total patches written so far, source path.
            print(i, patch_idx, image_path)
def eval(args):
    """Build the evaluation HDF5 file of paired low-/high-resolution images.

    Every path matched by ``<args.images_dir>/*`` is processed in sorted
    order. Each image is converted to RGB, resized so that both sides are
    multiples of ``args.scale`` (the HR image) and then shrunk by
    ``args.scale`` (the LR image), both with bicubic resampling. The pair is
    stored as datasets named by the image index under the ``hr`` and ``lr``
    groups of the file at ``args.output_path``.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    the script entry point calls it by this name.

    Args:
        args: Namespace providing ``images_dir``, ``output_path`` and
            ``scale``.
    """
    scale = args.scale

    # The context manager closes the HDF5 file even when an image fails to
    # load or resize, instead of only on the success path.
    with h5py.File(args.output_path, 'w') as h5_file:
        lr_group = h5_file.create_group('lr')
        hr_group = h5_file.create_group('hr')

        image_list = sorted(glob.glob('{}/*'.format(args.images_dir)))
        for i, image_path in enumerate(image_list):
            # convert() returns a converted copy, so the source file handle
            # can be released as soon as the conversion is done.
            with pil_image.open(image_path) as source:
                hr = source.convert('RGB')

            # Snap the image to the nearest lower multiple of `scale` so the
            # LR image has an exact integer size.
            hr_width = (hr.width // scale) * scale
            hr_height = (hr.height // scale) * scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)

            lr_group.create_dataset(str(i), data=np.array(lr))
            hr_group.create_dataset(str(i), data=np.array(hr))

            print(i)
if __name__ == '__main__':
    # Command-line entry point: parse the options, then build either the
    # evaluation file (with --eval) or the training file (the default).
    cli = argparse.ArgumentParser()
    cli.add_argument('--images-dir', type=str, required=True)
    cli.add_argument('--output-path', type=str, required=True)
    cli.add_argument('--scale', type=int, default=4)
    cli.add_argument('--eval', action='store_true')
    args = cli.parse_args()

    build = eval if args.eval else train
    build(args)