-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels_estimator.py
56 lines (45 loc) · 1.87 KB
/
models_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
This code and some of the related codes here are inspired from:
https://github.com/Bilkent-CTAR-Lab/DNN-for-Deletion-Channel
"""
import torch
import torch.nn as nn
class BI_Estimator(nn.Module):
    """Stacked bidirectional-GRU sequence estimator with an MLP head.

    Each of ``num_bi_layers`` bidirectional GRU layers is followed by a
    LayerNorm over the concatenated forward/backward hidden states.  The
    per-timestep features then go through an MLP (Linear + ReLU stack) and
    a final Linear + Sigmoid, producing one value in (0, 1) per timestep.
    The output is truncated along the time axis to the first
    ``actual_size`` steps (the input sequence may be padded beyond the
    actual length — TODO confirm against the caller).

    Args:
        input_size: number of features per input timestep.
        actual_size: number of leading timesteps kept in the output.
        d_rnn: hidden size of each GRU direction (layer output is 2*d_rnn).
        d_mlp: hidden-layer widths of the MLP head; defaults to [128, 32].
        num_bi_layers: number of stacked (GRU, LayerNorm) pairs.
    """

    def __init__(self, input_size, actual_size, d_rnn=512, d_mlp=None, num_bi_layers=3):
        super().__init__()
        # None sentinel avoids the shared-mutable-default pitfall while
        # keeping the original default of [128, 32].
        if d_mlp is None:
            d_mlp = [128, 32]
        self.actual_size = actual_size
        self.num_bi_layers = num_bi_layers
        self.d_rnn = d_rnn
        self.d_mlp = d_mlp
        # First GRU consumes the raw features; subsequent layers consume the
        # concatenated bidirectional hidden states (2 * d_rnn).
        self.bir_layers = nn.ModuleList([
            nn.GRU(input_size if i == 0 else self.d_rnn * 2, self.d_rnn,
                   bidirectional=True, batch_first=True)
            for i in range(self.num_bi_layers)
        ])
        # One LayerNorm after each bidirectional GRU layer.
        self.nor_layers = nn.ModuleList(
            [nn.LayerNorm(self.d_rnn * 2, eps=1e-3) for _ in range(self.num_bi_layers)]
        )
        # MLP head: a Linear + ReLU pair for every width in d_mlp.
        mlp_layers = []
        in_features = self.d_rnn * 2  # bidirectional output size
        for width in self.d_mlp:
            mlp_layers.append(nn.Linear(in_features, width))
            mlp_layers.append(nn.ReLU())
            in_features = width
        self.mlp_layers = nn.Sequential(*mlp_layers)
        # Scalar output per timestep, squashed to (0, 1) by the sigmoid.
        self.output_layer = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the estimator on a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_size); seq_len is
               assumed to be >= ``self.actual_size``.

        Returns:
            Tensor of shape (batch, actual_size, 1) with values in (0, 1).
        """
        for bir_layer, nor_layer in zip(self.bir_layers, self.nor_layers):
            x, _ = bir_layer(x)
            x = nor_layer(x)
        x = self.mlp_layers(x)
        x = self.output_layer(x)
        x = self.sigmoid(x)
        # Keep only the leading actual_size timesteps.
        return x[:, 0:self.actual_size, :]