-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmodel.py
121 lines (98 loc) · 3.53 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import math
import pdb
from collections import OrderedDict
class QuadNet(nn.Module):
    """Two-branch CNN with independent "real" and "template" streams.

    Each stream is its own 3-conv trunk plus a 2-layer classifier head.
    ``forward`` pushes (realA, realB) through the real stream and
    (tempA, tempB) through the template stream and returns all four logits.

    Inputs are assumed to be 3-channel images of spatial size 48x48 so the
    trunk reduces to 250x3x3 before flattening (48 -conv7-> 42 -pool-> 21
    -conv4-> 18 -pool-> 9 -conv4-> 6 -pool-> 3).
    """

    def __init__(self, num_classes=10):
        """Build both streams.

        Args:
            num_classes: size of each output logit vector.
        """
        super(QuadNet, self).__init__()
        # "Real" stream trunk: 3x48x48 -> 100x21x21 -> 150x9x9 -> 250x3x3.
        self.featReal = nn.Sequential(
            nn.Conv2d(3, 100, kernel_size=7, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(100, 150, kernel_size=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(150, 250, kernel_size=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # "Real" stream head: flattened 250*3*3 features -> logits.
        self.fcReal = nn.Sequential(
            nn.Linear(3 * 3 * 250, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, num_classes),
        )
        # "Template" stream: identical architecture, separate weights.
        self.featTemp = nn.Sequential(
            nn.Conv2d(3, 100, kernel_size=7, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(100, 150, kernel_size=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(150, 250, kernel_size=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.fcTemp = nn.Sequential(
            nn.Linear(3 * 3 * 250, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, num_classes),
        )

    def init_weights(self):
        """Xavier-initialize all conv weights and zero their biases in place.

        Linear layers keep PyTorch's default initialization, matching the
        original behavior.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fix: xavier_uniform (no trailing underscore) is deprecated;
                # use the in-place variant on the parameter directly instead
                # of going through the discouraged .data attribute.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward_real(self, x):
        """Run one batch through the "real" stream; returns (N, num_classes)."""
        x = self.featReal(x)
        x = x.view(-1, 3 * 3 * 250)
        x = self.fcReal(x)
        return x

    def forward_temp(self, x):
        """Run one batch through the "template" stream; returns (N, num_classes)."""
        x = self.featTemp(x)
        x = x.view(-1, 3 * 3 * 250)
        x = self.fcTemp(x)
        return x

    def forward(self, realA, realB, tempA, tempB):
        """Score all four inputs; returns logits (RA, RB, TA, TB)."""
        RA = self.forward_real(realA)
        RB = self.forward_real(realB)
        TA = self.forward_temp(tempA)
        TB = self.forward_temp(tempB)
        return RA, RB, TA, TB
class QuadNetSingle(nn.Module):
    """Single-stream (weight-shared) variant of QuadNet.

    One conv trunk and classifier head process all four inputs, so the
    four outputs come from identical weights (a Siamese-style setup).
    Expects 3-channel 48x48 inputs, reducing to 250x3x3 before the
    fully-connected head (48 -conv7-> 42 -pool-> 21 -conv4-> 18 -pool-> 9
    -conv4-> 6 -pool-> 3).
    """

    def __init__(self, num_classes=10):
        """Build the shared trunk and head.

        Args:
            num_classes: size of the output logit vector.
        """
        super(QuadNetSingle, self).__init__()
        self.conv1 = nn.Conv2d(3, 100, kernel_size=7, padding=0)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(100, 150, kernel_size=4, padding=0)
        self.conv3 = nn.Conv2d(150, 250, kernel_size=4, padding=0)
        self.fc1 = nn.Linear(3 * 3 * 250, 300)
        self.fc2 = nn.Linear(300, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        """Xavier-initialize all conv weights and zero their biases in place.

        Linear layers keep PyTorch's default initialization, matching the
        original behavior.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fix: xavier_uniform (no trailing underscore) is deprecated;
                # use the in-place variant on the parameter directly instead
                # of going through the discouraged .data attribute.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward_once(self, x):
        """Run one batch through the shared stack; returns (N, num_classes)."""
        x = self.conv1(x)            # 100x42x42
        x = self.relu(x)             # 100x42x42
        x1 = self.pool(x)            # 100x21x21
        x = self.conv2(x1)           # 150x18x18
        x = self.relu(x)             # 150x18x18
        x2 = self.pool(x)            # 150x9x9
        x = self.conv3(x2)           # 250x6x6
        x = self.relu(x)             # 250x6x6
        x3 = self.pool(x)            # 250x3x3
        xv = x3.view(-1, 3 * 3 * 250)
        xfc1 = self.relu(self.fc1(xv))
        output = self.fc2(xfc1)
        return output

    def forward(self, realA, realB, tempA, tempB):
        """Score all four inputs with the shared weights; returns (RA, RB, TA, TB)."""
        RA = self.forward_once(realA)
        RB = self.forward_once(realB)
        TA = self.forward_once(tempA)
        TB = self.forward_once(tempB)
        return RA, RB, TA, TB