-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
138 lines (122 loc) · 4.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from pix2pix.src.model.models import generator_unet_upsampling
from pix2pix.src.model import models
from PIL import Image
import h5py
import cv2
import sys
import os
import numpy as np
import cv2
from keras.applications.resnet50 import preprocess_input
from keras.models import load_model
from config import CONFIG
#from inf_cycle import test, test_dep
# Side length (pixels) used when resizing inputs for the 'CNN' model.
img_dim = 224
# Directory that generated images are written to.
output_path = 'static/results'

# Registry of pre-trained models, keyed first by task and then by model name.
# It is only built when CONFIG['development'] is falsy; in development mode
# `model_list` is left undefined and the inference functions cannot be used.
if not CONFIG['development']:
    # The original 'pix2pix' entries called `model.load_weights(...)` and
    # `model.load_model(...)` on a name `model` that does not exist at module
    # scope (NameError on import). They now use the imported `load_model`,
    # like the sibling entries.
    # NOTE(review): assumes the pix2pix .h5 files are full saved models
    # (architecture + weights), not weights-only checkpoints -- confirm.
    model_list = {
        'pix2depth': {
            'pix2pix': load_model('./weights/p2d_pix2pix.h5'),
            'CycleGAN': load_model('./weights/p2d_cycle.h5'),
            'CNN': load_model('./weights/p2d_cnn.h5'),
        },
        'depth2pix': {
            'pix2pix': load_model('./weights/d2p_pix2pix.h5'),
            'CycleGAN': load_model('./weights/d2p_cycle.h5'),
        },
    }
def pix2depth(path, model):
    """Produce a depth image for the picture at *path* and save it.

    Args:
        path: Path of the input image. Its last '/'-separated component is
            reused in the output file name.
        model: Key into ``model_list['pix2depth']``; the code branches on
            'CNN' and 'CycleGAN' and treats any other key as the default
            (256x256, scaled to [0, 1]) pipeline.

    Returns:
        The output path ``<output_path>/<model>_<basename of path>``.

    Raises:
        KeyError: If *model* is not registered in ``model_list['pix2depth']``.
    """
    # Local import so this function does not depend on the file's import block.
    import subprocess

    # Earlier experiment code that lived here referenced names that are not
    # defined anywhere in scope (model_name, bn_mode, generator, nb_patch,
    # use_mbd, batch_size) and rebound the `model` parameter, so every call
    # raised NameError. It has been removed.
    loaded_model = model_list['pix2depth'][model]
    file_name = model + '_' + path.split('/')[-1]
    output_file = os.path.join(output_path, file_name)

    if model == 'CycleGAN':
        # No inference is run for CycleGAN here (the test() call is disabled);
        # a fixed result file is copied instead. An argument list is used
        # rather than a shell string so spaces or shell metacharacters in the
        # file name cannot break or alter the command. As before, the exit
        # status of `cp` is ignored.
        subprocess.call(
            ['cp', 'gautam/inf_results/imgs/fakeA_0_0.jpg', output_file])
        return output_file

    originalImage = cv2.imread(path)
    if model == 'CNN':
        originalImage = cv2.resize(originalImage, (img_dim, img_dim))
        # `/ 1.` converts the pixel array to float before preprocessing.
        x = preprocess_input(originalImage / 1.)
    else:
        originalImage = cv2.resize(originalImage, (256, 256))
        x = originalImage / 255.

    p1 = get_depth_map(x, loaded_model)
    cv2.imwrite(output_file, p1)
    return output_file
def depth2pix(path, model):
    """Produce an image from the depth map at *path* and save it.

    Args:
        path: Path of the input depth image. Its last '/'-separated component
            is reused in the output file name.
        model: Key into ``model_list['depth2pix']``; the code branches on
            'CNN' and 'CycleGAN' and treats any other key as the default
            (256x256, scaled to [0, 1]) pipeline.

    Returns:
        The output path ``<output_path>/<model>_<basename of path>``.

    Raises:
        KeyError: If *model* is not registered in ``model_list['depth2pix']``.
    """
    # Local import so this function does not depend on the file's import block.
    import subprocess

    loaded_model = model_list['depth2pix'][model]
    file_name = model + '_' + path.split('/')[-1]
    output_file = os.path.join(output_path, file_name)

    if model == 'CycleGAN':
        # No inference is run for CycleGAN here (the test_dep() call is
        # disabled); a fixed result file is copied instead. An argument list
        # is used rather than a shell string so spaces or shell metacharacters
        # in the file name cannot break or alter the command. As before, the
        # exit status of `cp` is ignored.
        subprocess.call(
            ['cp', 'gautam/inf_results/imgs/fakeB_0_0.jpg', output_file])
        return output_file

    originalImage = cv2.imread(path)
    if model == 'CNN':
        # This branch resizes to 256, not the module-level img_dim.
        cnn_dim = 256
        originalImage = cv2.resize(originalImage, (cnn_dim, cnn_dim))
        # `/ 1.` converts the pixel array to float before preprocessing.
        x = preprocess_input(originalImage / 1.)
    else:
        originalImage = cv2.resize(originalImage, (256, 256))
        x = originalImage / 255.

    p1 = get_depth_map(x, loaded_model)
    cv2.imwrite(output_file, p1)
    return output_file
def blur_effect(image, depthImage, outputPath):
    """Blur the parts of *image* selected by *depthImage* and save the result.

    Pixels whose depth value is strictly greater than 200 are taken from a
    5x5 Gaussian-blurred copy of the image; all other pixels are kept as-is.

    Args:
        image: Colour image array; it is resized to the depth map's size.
        depthImage: Depth map array, either 2-D or 3-D. A 3-D map is reduced
            to 2-D by averaging over its last axis.
        outputPath: File path the composited image is written to.

    Returns:
        True if the image was processed and written, False if any step
        raised (the exception is printed, not propagated).
    """
    try:
        if depthImage.ndim == 3:
            # Truncate to integers, matching the original reduction.
            depthImage = np.mean(depthImage, axis=-1).astype(int)
        h, w = depthImage.shape

        # cv2.resize takes dsize as (width, height); passing (h, w) swapped
        # the axes and broke non-square inputs.
        image = cv2.resize(image, (w, h))
        blurredImage = cv2.GaussianBlur(image, (5, 5), 0)

        thresh = 200
        # Equivalent to cv2.threshold(..., THRESH_BINARY) == 255, but done in
        # NumPy so it also works for integer arrays that cv2.threshold rejects.
        blur_mask = depthImage > thresh

        # Vectorized per-pixel selection; the added axis broadcasts the 2-D
        # mask across the colour channels. The result keeps the image's dtype
        # instead of the removed `np.int` alias.
        new_image = np.where(blur_mask[..., np.newaxis], blurredImage, image)

        cv2.imwrite(outputPath, new_image)
        return True
    except Exception as e:
        # Best-effort contract: callers test the boolean result.
        print(e)
        return False
def get_depth_map(input_image, model):
    """Run *model* on a single input and return its prediction scaled by 255.

    Args:
        input_image: Array-like input; it is wrapped into a batch of one.
        model: Object exposing ``predict(batch, batch_size=...)``.

    Returns:
        The first element of the model's prediction, multiplied by 255.
    """
    single_batch = np.array([input_image])
    predictions = model.predict(single_batch, batch_size=1)
    first_prediction = predictions[0]
    return first_prediction * 255.
def portrait_mode(path, model):
    """Create a depth-based blurred version of the image at *path*.

    The depth map is generated with ``pix2depth``, read back as a
    single-channel image and passed to ``blur_effect`` together with the
    original picture.

    Args:
        path: Path of the input image.
        model: Model key forwarded to ``pix2depth``; also used as a prefix in
            the output file name.

    Returns:
        The path ``<output_path>/portrait_<model>_<basename of path>`` when
        ``blur_effect`` succeeds, otherwise None.
    """
    source_image = cv2.imread(path)
    base_name = model + '_' + path.split('/')[-1]
    depth_file = pix2depth(path, model)
    destination = os.path.join(output_path, 'portrait_' + base_name)
    # Flag 0 loads the depth map as a single-channel (grayscale) image.
    depth_gray = cv2.imread(depth_file, 0)
    succeeded = blur_effect(source_image, depth_gray, destination)
    if not succeeded:
        return None
    return destination