-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathdemo.py
51 lines (38 loc) · 1.62 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import torch
import torch.nn.parallel
from models import modules, net, resnet, densenet, senet
import numpy as np
import loaddata_demo as loaddata
import pdb
import matplotlib.image
import matplotlib.pyplot as plt
# Module-level side effect: make "jet" matplotlib's default colormap as soon
# as this file is imported or run.
# NOTE(review): presumably this is what colors the image written by
# matplotlib.image.imsave later in this script — confirm.
plt.set_cmap("jet")
def define_model(is_resnet, is_densenet, is_senet):
    """Build the network on top of the encoder backbone selected by the flags.

    Args:
        is_resnet: select ``resnet.resnet50(pretrained=True)`` as backbone.
        is_densenet: select ``densenet.densenet161(pretrained=True)``.
        is_senet: select ``senet.senet154(pretrained='imagenet')``.

    Returns:
        The object produced by ``net.model`` for the chosen encoder.

    Raises:
        ValueError: if none of the three flags is truthy.
    """
    # If several flags are set, precedence is senet > densenet > resnet.
    # This yields the same model as evaluating three independent ``if``
    # statements in resnet/densenet/senet order (the last assignment wins),
    # but without constructing backbones that would be thrown away.
    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        encoder = modules.E_senet(original_model)
        model = net.model(encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        model = net.model(encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        encoder = modules.E_resnet(original_model)
        model = net.model(encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # ``return model``.
        raise ValueError(
            "define_model: no encoder selected; set one of is_resnet, "
            "is_densenet or is_senet to True"
        )
    return model
def main():
    """Run the single-image demo with the SENet-based model.

    Builds the SENet variant through ``define_model``, wraps it in
    ``torch.nn.DataParallel`` on CUDA, restores the weights stored at
    ``./pretrained_model/model_senet``, switches to eval mode, and passes the
    loader for ``data/demo/img_nyu2.png`` to ``test``.
    """
    network = define_model(is_resnet=False, is_densenet=False, is_senet=True)

    # The DataParallel wrapper is applied before the checkpoint is loaded.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with "module.") — confirm before reordering.
    network = torch.nn.DataParallel(network)
    network = network.cuda()

    state = torch.load('./pretrained_model/model_senet')
    network.load_state_dict(state)
    network.eval()

    demo_loader = loaddata.readNyu2('data/demo/img_nyu2.png')
    test(demo_loader, network)
def test(nyu2_loader, model):
    """Run ``model`` on each item of ``nyu2_loader`` and save the output image.

    Args:
        nyu2_loader: iterable yielding image tensors; each one is moved to
            the current CUDA device before the forward pass.
        model: callable network whose output is indexed as a 4-D tensor.

    Every output is written to ``data/demo/out.png``, so when the loader
    yields more than one item only the last result remains on disk.
    """
    # Inference only. ``torch.no_grad()`` replaces the removed
    # ``Variable(..., volatile=True)`` flag, which PyTorch >= 0.4 ignores
    # (with a warning), causing an autograd graph to be built needlessly.
    with torch.no_grad():
        for image in nyu2_loader:
            out = model(image.cuda())
            # Reshape to (size(2), size(3)). The view only succeeds when the
            # first two dimensions together hold exactly one element.
            # NOTE(review): presumably batch size 1 and a single output
            # channel — confirm against the loader and the network.
            depth = out.view(out.size(2), out.size(3)).cpu().numpy()
            matplotlib.image.imsave('data/demo/out.png', depth)
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()