-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSup_predict_5comps.py
129 lines (89 loc) · 4.03 KB
/
Sup_predict_5comps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# coding: utf-8
# In[1]:
import argparse
import os
import sys
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# get_ipython().run_line_magic('matplotlib', 'inline')
import skimage.io as io
from skimage.transform import resize
import numpy as np
import tensorflow as tf
from func.data_generator import DataGenerator
from func.unet_model import NeuralNetwork
from func.tool import get_fname
from func.plot import plt_result
########
# In[2]:
# ---------------------------------------------------------------------------
# Command-line configuration.
# ---------------------------------------------------------------------------


def _str2bool(value):
    """Convert a command-line string such as ``"True"`` / ``"false"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so ``--test_mode False`` would still switch
    test mode on. This converter accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
# Exported as CUDA_VISIBLE_DEVICES below.
parser.add_argument('--gpu', dest='gpu', default='6')
# Not read in this script; presumably consumed by NeuralNetwork via ``args`` — confirm.
parser.add_argument('--test_mode', dest='test_mode', default=False, type=_str2bool)
### model using
# Directory holding the 'model_ckpt' checkpoint; prediction outputs are written under it too.
parser.add_argument('--model_dir', dest='model_dir', default='model/Unet_5comps/5comps_model/')
### data
# Directory of input images to run prediction on.
parser.add_argument('--XX_DIR', dest='XX_DIR', default='data/processed/finaluse/A_rmbg_00_img/')
# Input shape (height, width, channels) handed to DataGenerator.
parser.add_argument('--image-h', dest='im_h', default=256, type=int)
parser.add_argument('--image-w', dest='im_w', default=256, type=int)
parser.add_argument('--image-d', dest='im_c', default=3, type=int)
# Number of output mask channels read from each prediction.
parser.add_argument('--num_class', dest='num_class', default=5, type=int)
# Presumably the dropout keep probability (1.0 = keep everything) — confirm in NeuralNetwork.
# Default is a float because argparse does not apply ``type`` to non-string defaults.
parser.add_argument('--keep_prob', dest='keep_prob', default=1.0, type=float)
### model
parser.add_argument('--bz', '--batch-size', dest='bz', type=int, default=16)
### return parser
args = parser.parse_args()
### set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# In[3]:
# Build the network graph and attach its saver. NeuralNetwork receives the
# parsed CLI ``args`` (presumably for shapes / num_class / keep_prob — confirm).
md = NeuralNetwork(args=args)
md.build_graph()
md.attach_saver()
# In[4]:
# Collect the images to predict. ``os.listdir`` returns entries in arbitrary
# order and also lists sub-directories and hidden files (e.g. '.DS_Store'),
# which the image reader cannot load. Keep only regular, non-hidden files and
# sort them so repeated runs process images in the same order.
Xname = sorted(
    name for name in os.listdir(args.XX_DIR)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(args.XX_DIR, name))
)
Xpath = [os.path.join(args.XX_DIR, name) for name in Xname]
datagen = DataGenerator(input_shape=(args.im_h, args.im_w, args.im_c))
pred_gen = datagen.get_tesri_data(Xpath, bz=args.bz)
# Number of mini-batches needed to cover every path (integer ceil-division);
# assumes the generator yields a final partial batch — TODO confirm.
pred_iter = -(-len(Xpath) // args.bz)
# In[5]:
model_dir = args.model_dir
save_dir = model_dir
# os.path.join works whether --model_dir ends with a slash or not.
checkpoint_path = os.path.join(model_dir, 'model_ckpt')

# Output sub-directory names for the first five mask channels.
part_name = ['A_rmbg_01_Body', 'A_rmbg_02_Left_fore', 'A_rmbg_03_Right_fore',
             'A_rmbg_04_Left_hind', 'A_rmbg_05_Right_hind']


def _part_dir_name(ch):
    """Return the output sub-directory name for mask channel ``ch``.

    Channels beyond the named parts (``--num_class`` > 5) get a generic
    ``class_NN`` name instead of raising IndexError.
    """
    return part_name[ch] if ch < len(part_name) else 'class_%02d' % (ch + 1)


def _binarize_channel(channel, threshold=0.5):
    """Min-max normalise a 2-D score map and threshold it to a {0.0, 1.0} mask.

    A constant map (max == min) would divide by zero and produce NaNs, which
    never exceed ``threshold``; return that all-zero mask directly instead of
    emitting RuntimeWarnings.
    """
    lo = channel.min()
    span = channel.max() - lo
    if span == 0:
        return np.zeros(channel.shape, dtype=np.float64)
    norm = (channel - lo) / span
    return np.where(norm > threshold, 1.0, 0.0)


with tf.Session(graph=md.graph) as sess:
    sess.run(tf.global_variables_initializer())  # Variable initialization.
    # NOTE(review): import_meta_graph re-imports the saved graph into md.graph,
    # which already contains the model built above. Confirm that this restore
    # populates the variables md.y_pred_tf actually reads. Restoring through the
    # saver from md.attach_saver() would avoid the duplicate import.
    meta_to_restore = checkpoint_path + '.meta'
    saver = tf.train.import_meta_graph(meta_to_restore)
    saver.restore(sess, checkpoint_path)
    print('Model Restored')

    save_to_mask_root = os.path.join(save_dir, 'Predict_mask')
    save_to_check = os.path.join(save_dir, 'Predict_checking')
    for pred_batch_i in range(pred_iter):
        print('\r[Predict]-----pred-mini-Batch ({}/{})'.format(pred_batch_i+1, pred_iter), end='\r')
        x_pred_batch, path = next(pred_gen)
        pred_batch = sess.run([md.y_pred_tf],
                              feed_dict={md.x_data_tf: x_pred_batch})
        # pred_batch[0] is the y_pred_tf output for this mini-batch, indexed
        # here as (batch, H, W, channel).
        for i, pred in enumerate(pred_batch[0]):
            fname = get_fname(path[i])
            # Resize the source image to the network input size for the
            # side-by-side check figure; 3 channels to match the RGB-stacked masks.
            ori_img = resize(io.imread(path[i]),
                             output_shape=(args.im_h, args.im_w, 3))
            mask_list = []
            for ch in range(args.num_class):
                bin_mask = _binarize_channel(pred[:, :, ch])
                # 3-channel copy for plotting next to the RGB image.
                mask_list.append(np.stack([bin_mask, bin_mask, bin_mask], axis=2))
                save_to_mask = os.path.join(save_to_mask_root, _part_dir_name(ch))
                os.makedirs(save_to_mask, exist_ok=True)
                # Explicit 0/255 uint8 instead of a float64 0/1 array, avoiding
                # the lossy-conversion warning on save.
                io.imsave(os.path.join(save_to_mask, '%s.png' % (fname)),
                          (bin_mask * 255).astype(np.uint8))
            img_list = [ori_img] + mask_list
            title_list = ['Original image'] + [''] * len(mask_list)
            fig = plt_result(img_list, title_list)
            os.makedirs(save_to_check, exist_ok=True)
            fig.savefig(os.path.join(save_to_check, '%s.png' % (fname)),
                        dpi=100, format='png', bbox_inches='tight')
            # Release the figure; otherwise one figure per image accumulates in memory.
            plt.close(fig)