-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
54 lines (44 loc) · 1.48 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
from src.networks import TransferModel
import matplotlib.pyplot as plt
from src.train_utils.imaging_utils import deprocess_input
import argparse
import glob
def _load_image(img_path, height, width):
    """Read ``img_path`` and return a (1, height, width, 3) VGG16-preprocessed batch.

    ``plt.imread`` returns PNG files as float arrays in [0, 1] and other
    formats as integer arrays. Depending on the file it returns a 2-D array
    (grayscale) or a 4-channel array (RGBA). VGG16's ``preprocess_input``
    expects 3-channel RGB values in [0, 255], so the array is normalised to
    that layout before preprocessing.
    """
    img = plt.imread(img_path)
    if img.ndim == 2:
        # Grayscale: replicate the single channel into R, G and B.
        img = img[:, :, None].repeat(3, axis=2)
    # Drop an alpha channel if present; the model takes exactly 3 channels.
    img = img[..., :3]
    if img.dtype.kind == 'f':
        # matplotlib reads PNGs as floats in [0, 1]; rescale to [0, 255].
        img = img * 255.0
    img = tf.keras.applications.vgg16.preprocess_input(img)
    img = tf.image.resize(img, size=(height, width))
    return tf.expand_dims(img, axis=0)


def _styled_output_path(img_path):
    """Return ``<name>_styled.png`` for an input ``<name>.<ext>``.

    Only the final extension is swapped, so directory names that happen to
    contain the extension string are left intact.
    """
    root, _ = os.path.splitext(img_path)
    return f'{root}_styled.png'


def sample(args):
    """Stylise every ``*.<img_ext>`` image in ``args.img_path`` with a TransferModel.

    Each input ``<name>.<img_ext>`` produces ``<name>_styled.png`` in the same
    directory. Files already ending in ``_styled.png`` are treated as outputs
    of a previous run and skipped.

    Args:
        args: parsed command-line namespace. Reads ``height`` and ``width``
            (model input size), ``model`` (weights path given to
            ``load_weights``), ``img_path`` (directory to scan) and
            ``img_ext`` (input file extension, without the dot).
    """
    pattern = os.path.join(glob.escape(args.img_path), f'*.{args.img_ext}')
    # Sorted for a deterministic processing order. When img_ext is 'png', the
    # glob also matches results of earlier runs; don't stylise those again.
    img_paths = [
        path for path in sorted(glob.glob(pattern))
        if not path.endswith('_styled.png')
    ]
    if not img_paths:
        print(f'No *.{args.img_ext} images found in {args.img_path}')
        return

    model = TransferModel((args.height, args.width, 3))
    model.build(input_shape=(None, args.height, args.width, 3))
    model.load_weights(args.model)

    for img_path in img_paths:
        # The model returns a pair; only the first element (the styled
        # batch) is used here.
        styled_img, _ = model(_load_image(img_path, args.height, args.width))
        plt.imshow(deprocess_input(styled_img[0]))
        plt.axis('off')
        plt.savefig(_styled_output_path(img_path),
                    dpi=80,
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close()
def build_parser():
    """Build the command-line parser for this sampling script."""
    parser = argparse.ArgumentParser(
        description='Stylise a directory of images with a trained TransferModel.')
    parser.add_argument('--img_ext', type=str, default='png', required=False,
                        help='extension of the input images, without the dot')
    parser.add_argument('--width', type=int, default=640, required=False,
                        help='model input width in pixels')
    parser.add_argument('--height', type=int, default=320, required=False,
                        help='model input height in pixels')
    # Required, so argparse never falls back to a default.
    parser.add_argument('--model', type=str, required=True,
                        help='path to the trained weights (passed to load_weights)')
    # NOTE: developer-specific default directory; pass --img_path explicitly
    # on any other machine.
    parser.add_argument('--img_path', type=str,
                        default='/home/qbeer/pics_vesuvio', required=False,
                        help='directory containing the images to stylise')
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run ``sample``."""
    args = build_parser().parse_args(argv)
    sample(args)


if __name__ == '__main__':
    main()