Skip to content

Commit 4968041

Browse files
authored
Add files via upload
1 parent e0f3983 commit 4968041

9 files changed

+995
-0
lines changed

CC_NBDF_Net.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Apr 17 15:23:55 2020
4+
5+
@author: admin
6+
"""
7+
import torch
8+
import torch.nn as nn
9+
10+
11+
class CC_NBDF(nn.Module):
    """Channel-concatenation NBDF network.

    Stacked 2-D convolutions collapse the microphone axis of the input,
    then two (bi)LSTM layers and a time-distributed linear + sigmoid
    produce a per-frame mask.

    Input:  (batch, 2, n_mics, time) feature planes
            (presumably real/imag per mic -- TODO confirm with caller).
    Output: (batch, time, 1) mask values in [0, 1].
    """

    def __init__(self, feature_map_num=64, hidden1_dim=256, hidden2_dim=128,
                 num_direction=2, num_layers=1, biFlag=True):
        """Build the conv + LSTM stack.

        Args:
            feature_map_num: number of conv feature maps.
            hidden1_dim / hidden2_dim: hidden sizes of the two LSTMs.
            num_direction: 2 when bidirectional, 1 otherwise (must agree
                with biFlag).
            num_layers: stacked LSTM layers per RNN.
            biFlag: passed to nn.LSTM as `bidirectional`.
        """
        # BUG FIX: original signature omitted `self`, and super() was called
        # with the wrong class name (CRNN_NBDFNet) -- the module could never
        # be instantiated.
        super(CC_NBDF, self).__init__()

        self.feature_map_num = feature_map_num
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim
        self.num_layers = num_layers
        self.num_direction = num_direction
        # LSTM output width doubles when bidirectional.
        self.output1_dim = self.hidden1_dim * self.num_direction
        self.output2_dim = self.hidden2_dim * self.num_direction
        self.biFlag = biFlag

        self.BN = nn.BatchNorm2d(self.feature_map_num)
        self.relu = torch.nn.ReLU()
        # cnn1 maps the 2 input planes to feature maps; cnn2 is applied
        # repeatedly, shrinking the mic axis (dim 2) by 1 per pass.
        self.cnn1 = nn.Conv2d(2, self.feature_map_num, (2, 1))
        self.cnn2 = nn.Conv2d(self.feature_map_num, self.feature_map_num, (2, 1))

        self.rnn1 = nn.LSTM(input_size=self.feature_map_num,
                            hidden_size=self.hidden1_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=biFlag)
        self.rnn2 = nn.LSTM(input_size=self.output1_dim,
                            hidden_size=self.hidden2_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=biFlag)
        # Time-distributed linear layer: one mask value per frame.
        self.linearTimeDistributed = nn.Linear(self.output2_dim, 1)

    def forward(self, inputsignal):
        """Compute the per-frame mask.

        Args:
            inputsignal: tensor of shape (B, 2, M, T), M >= 2.
        Returns:
            tensor of shape (B, T, 1), values in [0, 1].
        """
        # (B,2,M,T) -> (B,F,M-1,T), e.g. (512,2,4,192) -> (512,64,3,192)
        cnn1out = self.relu(self.cnn1(inputsignal))
        # Recursive convolution until the mic axis collapses to size 1.
        while cnn1out.shape[2] != 1:
            cnn1out = self.relu(self.cnn2(cnn1out))
        # BUG FIX: squeeze only dim 2 -- a bare squeeze() would also drop a
        # batch (or time) dimension of size 1 and corrupt the output shape.
        cnn2out = torch.squeeze(cnn1out, 2)       # (B, F, T)
        cnn2out = torch.transpose(cnn2out, 1, 2)  # (B, T, F)
        rnn1out, _ = self.rnn1(cnn2out)
        rnn2out, _ = self.rnn2(rnn1out)
        # (B, T, output2_dim) -> (B, T, 1), squashed to [0, 1]
        outsignal = torch.sigmoid(self.linearTimeDistributed(rnn2out))
        return outsignal

NBDF_Net.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Apr 17 15:23:55 2020
4+
5+
@author: admin
6+
"""
7+
import torch
8+
import torch.nn as nn
9+
10+
11+
12+
13+
14+
class NBDF_Net(nn.Module):
    """Narrow-band deep filtering network.

    Two stacked (bi)LSTM layers over per-channel features, followed by a
    time-distributed linear + sigmoid that emits one mask value per frame.

    Input:  (batch, time, 2 * n_chanels)
    Output: (batch, time, 1) mask values in [0, 1].
    """

    def __init__(self, n_chanels, hidden1_dim=256, hidden2_dim=128,
                 num_direction=2, num_layers=1, biFlag=True):
        """Build the two-LSTM stack.

        Args:
            n_chanels: number of input channels; the net consumes two
                features per channel (presumably real/imag -- TODO confirm).
            hidden1_dim / hidden2_dim: hidden sizes of the two LSTMs.
            num_direction: 2 when bidirectional, 1 otherwise (must agree
                with biFlag).
            num_layers: stacked LSTM layers per RNN.
            biFlag: passed to nn.LSTM as `bidirectional`.
        """
        # BUG FIX: super() was called with the wrong class name (NBDFNet),
        # and `self.target = target` referenced an undefined name -- both
        # made instantiation impossible; the dead `target` line is removed.
        super(NBDF_Net, self).__init__()

        self.n_chanels = n_chanels
        # Two features per channel.
        self.input_dim = 2 * self.n_chanels
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim
        # LSTM output width doubles when bidirectional.
        self.output1_dim = self.hidden1_dim * num_direction
        self.output2_dim = self.hidden2_dim * num_direction
        self.num_layers = num_layers
        self.biFlag = biFlag

        self.rnn1 = nn.LSTM(input_size=self.input_dim,
                            hidden_size=self.hidden1_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=biFlag)
        self.rnn2 = nn.LSTM(input_size=self.output1_dim,
                            hidden_size=self.hidden2_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=biFlag)
        # Time-distributed linear layer: one mask value per frame.
        self.linearTimeDistributed = nn.Linear(self.output2_dim, 1)

    def forward(self, inputsignal):
        """Compute the per-frame mask.

        Args:
            inputsignal: tensor of shape (B, T, 2 * n_chanels).
        Returns:
            tensor of shape (B, T, 1), values in [0, 1].
        """
        rnn1out, _ = self.rnn1(inputsignal)
        rnn2out, _ = self.rnn2(rnn1out)
        # (B, T, output2_dim) -> (B, T, 1), squashed to [0, 1]
        outsignal = torch.sigmoid(self.linearTimeDistributed(rnn2out))
        return outsignal

NB_Dataset.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Apr 17 15:13:05 2020
4+
5+
@author: admin
6+
"""
7+
8+
import torch
9+
import torch.nn as nn
10+
# import torchvision
11+
# import torch.nn.functional as F
12+
from torch import optim
13+
# import torchvision.transforms as transforms
14+
from torch.utils.data import DataLoader, Dataset
15+
16+
17+
import os,fnmatch
18+
import numpy as np
19+
20+
21+
class NB_Dataset(Dataset):
    """Dataset over pre-batched .npz files on disk.

    Each file in `data_path` is named 'batch<idx>*.npz' and stores arrays
    'X' of shape (batch, time, feat) and 'mrm' of shape (batch, time).
    __getitem__ returns one whole pre-built batch, so a surrounding
    DataLoader is typically used with batch_size=1.
    """

    def __init__(self, data_path, batchsize=512, time_steps=192, shuffle=True):
        """Record the file location / crop sizes and build the index order.

        Args:
            data_path: directory containing the 'batch*.npz' files.
            batchsize: max rows taken from each stored batch.
            time_steps: max frames taken from each stored batch.
            shuffle: shuffle the file order in on_epoch_end().
        """
        self.data_path = data_path
        self.time_steps = time_steps
        self.batchsize = batchsize
        self.shuffle = shuffle
        self.on_epoch_end()  # build (and optionally shuffle) the index order

    def __getitem__(self, index):
        """Load one stored batch and return it as (X, y) float32 tensors."""
        # Resolve the file whose name starts with 'batch<idx>'.
        batchname = fnmatch.filter(os.listdir(self.data_path),
                                   'batch{}*'.format(self.indexes[index]))[0]
        # BUG FIX: os.path.join instead of string concatenation, which broke
        # whenever data_path lacked a trailing separator; the context manager
        # also closes the NpzFile's underlying file handle.
        with np.load(os.path.join(self.data_path, batchname)) as sample:
            X = sample['X'][:self.batchsize, :self.time_steps, :].astype('float32')
            # reshape(-1, ...) tolerates a final batch shorter than batchsize
            # (the original hard-coded batchsize and crashed on it).
            y = (sample['mrm'][:self.batchsize, :self.time_steps]
                 .reshape(-1, self.time_steps, 1).astype('float32'))
        return torch.from_numpy(X), torch.from_numpy(y)

    def __len__(self):
        """Number of 'batch*.npz' files currently in data_path."""
        return len(fnmatch.filter(os.listdir(self.data_path), 'batch*.npz'))

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the file index order.

        NOTE(review): a torch DataLoader never calls this automatically;
        the training loop must invoke it between epochs for the reshuffle
        to happen.
        """
        self.indexes = np.arange(len(self))
        if self.shuffle:
            np.random.shuffle(self.indexes)
53+
54+
55+
"""
56+
if __name__ == "__main__":
57+
58+
train_path = '../Array_position/train_val_batch/train_batch/'
59+
val_path = '../Array_position/train_val_batch/validation_batch/'
60+
61+
# wavFiles = fnmatch.filter(os.listdir(train_path),'batch10606*.npz')
62+
63+
64+
btz = 1
65+
66+
train_NBDataset = NBDataset(data_path = train_path, time_steps=192, shuffle = True)
67+
val_NBDataset = NBDataset(data_path = val_path, time_steps=192, shuffle = True)
68+
69+
70+
train_DataLoader = DataLoader(
71+
dataset=train_NBDataset, # torch TensorDataset format
72+
batch_size=btz, # mini batch size
73+
shuffle=True, # random shuffle for training
74+
drop_last=True,
75+
num_workers=0, # subprocesses for loading data
76+
)
77+
78+
val_DataLoader = DataLoader(
79+
dataset=val_NBDataset, # torch TensorDataset format
80+
batch_size=btz, # mini batch size
81+
shuffle=True, # random shuffle for training
82+
drop_last=True,
83+
num_workers=0, # subprocesses for loading data
84+
)
85+
86+
n,m = train_NBDataset[2]
87+
print(n.shape,m.shape)
88+
89+
for a,b in train_DataLoader:
90+
print(a.shape,b.shape)
91+
"""

PW_NBDF_Net.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Apr 17 15:23:55 2020
4+
5+
@author: admin
6+
"""
7+
import torch
8+
import torch.nn as nn
9+
from tqdm import tqdm
10+
11+
12+
13+
14+
class PW_NBDF(nn.Module):
    """Pair-wise NBDF network.

    The first LSTM runs independently over every (reference, other)
    channel pair; pair outputs are averaged, then a second LSTM and a
    time-distributed linear + sigmoid emit a 1-D mask per frame.

    Input:  (B, T, C) with C = 2 * n_mics (two features per mic); the
            first two features are treated as the reference channel.
    Output: (B, T) mask values in [0, 1].
    """

    def __init__(self, input_dim=4, hidden1_dim=256, hidden2_dim=128,
                 num_direction=2, device='cuda', num_layers=1, biFlag=True):
        """Build the pair-wise two-LSTM stack.

        Args:
            input_dim: per-pair feature size (2 ref + 2 other = 4).
            hidden1_dim / hidden2_dim: hidden sizes of the two LSTMs.
            num_direction: 2 when bidirectional, 1 otherwise (must agree
                with biFlag).
            device: device on which the pair buffer is allocated in
                forward(); must match the input's device.
            num_layers: stacked LSTM layers per RNN.
            biFlag: passed to nn.LSTM as `bidirectional`.
        """
        super(PW_NBDF, self).__init__()

        self.input_dim = input_dim
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim
        self.num_direction = num_direction
        # LSTM output width doubles when bidirectional.
        self.output1_dim = self.hidden1_dim * num_direction
        self.output2_dim = self.hidden2_dim * num_direction
        self.num_layers = num_layers
        self.device = device
        self.biFlag = biFlag

        self.rnn1 = nn.LSTM(input_size=self.input_dim,
                            hidden_size=self.hidden1_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=self.biFlag)
        self.rnn2 = nn.LSTM(input_size=self.output1_dim,
                            hidden_size=self.hidden2_dim,
                            num_layers=self.num_layers, batch_first=True,
                            bidirectional=self.biFlag)
        # Time-distributed linear layer: one mask value per frame.
        self.linearTimeDistributed = nn.Linear(self.output2_dim, 1)

    def forward(self, inputsignal):
        """Compute the per-frame mask from pair-wise channel features.

        Args:
            inputsignal: tensor of shape (B, T, C), C even and >= 4.
        Returns:
            tensor of shape (B, T), values in [0, 1].
        """
        B, T, C = inputsignal.shape            # (B, T, C)
        n_pairs = C // 2 - 1                   # pairs of (reference, other)
        # Buffer holding [ref_features | pair_features] for every pair.
        x = torch.zeros(B, n_pairs, T, 4, device=self.device)  # (B, N, T, 4)
        for i in range(n_pairs):
            x[:, i, :, :2] = inputsignal[:, :, :2]                    # reference
            x[:, i, :, 2:] = inputsignal[:, :, (i + 1) * 2:(i + 2) * 2]
        # Fold pairs into the batch axis so rnn1 treats each independently.
        x = x.view(B * n_pairs, T, 4)          # (B*N, T, 4)
        rnn1out, _ = self.rnn1(x)
        rnn1out = rnn1out.view(B, n_pairs, T, self.output1_dim)  # (B, N, T, D1*2)
        rnn1out_combined = torch.mean(rnn1out, dim=1)            # (B, T, D1*2)

        rnn2out, _ = self.rnn2(rnn1out_combined)                 # (B, T, D2*2)
        # BUG FIX: squeeze only the trailing singleton -- a bare squeeze()
        # also dropped a batch (or time) dimension of size 1, so B=1 inputs
        # returned shape (T,) instead of the documented (B, T).
        outsignal = torch.sigmoid(
            self.linearTimeDistributed(rnn2out)).squeeze(-1)     # (B, T)
        return outsignal  # 1-D mask output per frame

0 commit comments

Comments
 (0)