forked from dsindex/ntagger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
709 lines (624 loc) · 30.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
from __future__ import absolute_import, division, print_function
import os
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torchcrf import CRF
class BaseModel(nn.Module):
    """Common base class for the taggers in this module.

    Provides RNG seeding, numpy embedding-matrix loading, embedding-layer
    construction and a ``<key> <id>`` dictionary loader. ``forward`` is the
    identity and is expected to be overridden by subclasses.
    """

    def __init__(self, config=None):
        super(BaseModel, self).__init__()
        # Seed all RNGs only when a config whose 'opt' object carries a seed is given.
        if config and hasattr(config['opt'], 'seed'):
            self.set_seed(config['opt'])

    def set_seed(self, opt):
        """Seed python ``random``, numpy and torch (CPU and CUDA) with ``opt.seed``."""
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)

    def load_embedding(self, input_path):
        """Load an array saved with numpy from ``input_path`` and return it as a tensor."""
        weights_matrix = np.load(input_path)
        return torch.as_tensor(weights_matrix)

    def create_embedding_layer(self, vocab_dim, emb_dim, weights_matrix=None, non_trainable=True, padding_idx=0):
        """Build an ``nn.Embedding``, optionally initialised from ``weights_matrix``.

        Args:
            vocab_dim: number of rows (vocabulary size).
            emb_dim: embedding width.
            weights_matrix: optional tensor of shape [vocab_dim, emb_dim] copied
                into the layer's weight.
            non_trainable: when True, the weight is frozen (``requires_grad=False``).
            padding_idx: index passed to ``nn.Embedding(padding_idx=...)``.

        Returns:
            The embedding layer.
        """
        emb_layer = nn.Embedding(vocab_dim, emb_dim, padding_idx=padding_idx)
        if torch.is_tensor(weights_matrix):
            emb_layer.load_state_dict({'weight': weights_matrix})
        if non_trainable:
            emb_layer.weight.requires_grad = False
        return emb_layer

    def load_dict(self, input_path):
        """Read a whitespace-separated ``<key> <id>`` file into ``{id: key}``.

        Blank lines (for example a trailing newline at the end of the file)
        are skipped instead of raising ``IndexError``.

        Args:
            input_path: path to a UTF-8 text file, one ``<key> <id>`` per line.

        Returns:
            dict mapping the integer id to its key string.
        """
        dic = {}
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                toks = line.split()
                if not toks:
                    continue
                dic[int(toks[1])] = toks[0]
        return dic

    def forward(self, x):
        """Identity; subclasses override this."""
        return x
class TextCNN(nn.Module):
    """Multi-width 1-D convolution with max-over-time pooling.

    One ``nn.Conv1d`` is built per entry of ``kernel_sizes``. Each feature map
    goes through ReLU, is max-pooled over the sequence axis, and the pooled
    vectors are concatenated. Output width is ``len(kernel_sizes) * out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes):
        super(TextCNN, self).__init__()
        self.convs = nn.ModuleList(
            [nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=size)
             for size in kernel_sizes]
        )

    def forward(self, x):
        # x : [batch_size, seq_size, emb_dim] -> Conv1d wants channels first
        channels_first = x.transpose(1, 2)
        # channels_first : [batch_size, emb_dim, seq_size]
        pooled_maps = []
        for conv in self.convs:
            feature_map = F.relu(conv(channels_first))
            # feature_map : [batch_size, out_channels, seq_size - kernel_size + 1]
            # torch.max instead of F.max_pool1d keeps the graph ONNX-exportable
            pooled_maps.append(feature_map.max(dim=2).values)
        # result : [batch_size, len(kernel_sizes) * out_channels]
        return torch.cat(pooled_maps, dim=1)
class DenseNet(nn.Module):
    """Densely connected 1-D convolution stack over a token sequence.

    ``densenet_kernels`` is a list of rows: its length is the depth and each
    row lists one kernel size per column (the width; taken from row 0).

    For each column ``j``:
      * layer 0 convolves the input embeddings;
      * layer ``i > 0`` convolves the concatenation of that column's previous
        layer outputs.
    The last output of every column is concatenated with the input and
    projected by a kernel-size-1 convolution to ``last_num_filters`` channels.
    Outputs are multiplied by the sequence mask before each activation.

    Args:
        densenet_kernels: list (depth) of lists (width) of kernel sizes.
        emb_dim: input feature size.
        first_num_filters: output channels of layer 0.
        num_filters: output channels of layers 1..depth-1.
        last_num_filters: output channels of the final projection.
        activation: activation applied after every inner convolution.
    """

    def __init__(self, densenet_kernels, emb_dim, first_num_filters, num_filters, last_num_filters, activation=F.relu):
        super(DenseNet, self).__init__()
        self.activation = activation
        self.densenet_kernels = densenet_kernels
        self.densenet_width = len(densenet_kernels[0])
        self.densenet_block = []
        for i, kss in enumerate(self.densenet_kernels):  # densenet depth
            if i == 0:
                in_channels = emb_dim
                out_channels = first_num_filters
            else:
                # concat of layer 0 (first_num_filters) and layers 1..i-1 (num_filters each)
                in_channels = first_num_filters + num_filters * (i - 1)
                out_channels = num_filters
            convs = []
            for ks in kss:  # densenet width
                # length-preserving for odd ks; NOTE(review): an even ks shortens the sequence by 1
                padding = (ks - 1) // 2
                conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=ks, padding=padding)
                convs.append(conv)
            self.densenet_block.append(nn.ModuleList(convs))
        self.densenet_block = nn.ModuleList(self.densenet_block)
        # The last layer of every column emits first_num_filters channels when the stack
        # has depth 1, and num_filters otherwise. Previously num_filters was assumed
        # unconditionally, which broke depth-1 stacks with first_num_filters != num_filters.
        column_out_channels = first_num_filters if len(self.densenet_kernels) == 1 else num_filters
        ks = 1
        in_channels = emb_dim + column_out_channels * self.densenet_width
        out_channels = last_num_filters
        padding = (ks - 1) // 2
        self.conv_last = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=ks, padding=padding)
        self.last_dim = last_num_filters

    def forward(self, x, mask):
        """Run the stack.

        Args:
            x: [batch_size, seq_size, emb_dim] input features.
            mask: [batch_size, seq_size]; zero marks padding.

        Returns:
            [batch_size, seq_size, last_num_filters]; padded positions are 0.
        """
        x = x.permute(0, 2, 1)
        # x : [batch_size, emb_dim, seq_size]
        masks = mask.unsqueeze(1).to(torch.float)
        # masks : [batch_size, 1, seq_size], broadcast over the channel axis
        merge_list = []
        for j in range(self.densenet_width):
            conv_results = []
            for i in range(len(self.densenet_kernels)):
                conv_in = x if i == 0 else torch.cat(conv_results, dim=-2)
                conv_out = self.densenet_block[i][j](conv_in)
                # conv_out : [batch_size, first_num_filters or num_filters, seq_size]
                conv_out *= masks
                conv_out = self.activation(conv_out)
                conv_results.append(conv_out)
            merge_list.append(conv_results[-1])  # last layer of the column only
        conv_last = self.conv_last(torch.cat([x] + merge_list, dim=-2))
        conv_last *= masks
        conv_last = F.relu(conv_last)
        # conv_last : [batch_size, last_num_filters, seq_size]
        return conv_last.permute(0, 2, 1)
class DSA(nn.Module):
    """Dynamic self-attention (DSA) pooling over a sequence.

    Each of ``dsa_num_attentions`` heads projects the input with
    ``Linear(dsa_input_dim, dsa_dim)`` + leaky ReLU. It then refines attention
    logits for ``dsa_r`` rounds and keeps the pooled vector of the last round.
    The heads' vectors are concatenated, so ``last_dim`` is
    ``dsa_num_attentions * dsa_dim``.
    """

    def __init__(self, config, dsa_num_attentions, dsa_input_dim, dsa_dim, dsa_r=3):
        super(DSA, self).__init__()
        self.config = config
        self.device = config['opt'].device
        dsa = []
        for _ in range(dsa_num_attentions):
            dsa.append(nn.Linear(dsa_input_dim, dsa_dim))
        self.dsa = nn.ModuleList(dsa)
        self.dsa_r = dsa_r  # r iterations
        self.last_dim = dsa_num_attentions * dsa_dim

    def __self_attention(self, x, mask, r=3):
        """Pool ``x`` into one vector per batch row using ``r`` attention rounds.

        Args:
            x: [batch_size, seq_size, dsa_dim] projected inputs.
            mask: [batch_size, seq_size]; zero marks padding.
            r: number of rounds (must be >= 1).

        Returns:
            [batch_size, dsa_dim] pooled vector of the last round.
        """
        mask = mask.to(device=x.device, dtype=torch.float)
        inv_mask = mask.eq(0.0)
        # inv_mask : [batch_size, seq_size], True at padded positions
        q = torch.zeros(mask.shape[0], mask.shape[-1], dtype=torch.float, device=x.device)
        # q : [batch_size, seq_size] attention logits, zero at the start
        z_list = []
        for _ in range(r):
            # Overwrite padded logits with a large negative value so softmax gives them ~0 weight.
            # (Multiplying q by a {1, -1e20} tensor flipped negative padded logits to huge
            # positive ones and could overflow to inf/nan over several rounds.)
            q = q.masked_fill(inv_mask, -1e20)
            # attention weights are treated as constants; gradients reach x through s only
            a = torch.softmax(q.detach(), dim=-1)
            # explicit zeroing still matters for rows that are entirely padding
            a = a * mask
            s = (a.unsqueeze(2) * x).sum(1)
            # s : [batch_size, dsa_dim]
            z = torch.tanh(s)
            z_list.append(z)
            # update logits with agreement between each position and the pooled vector
            q = q + torch.matmul(x, z.unsqueeze(2)).squeeze(2)
        return z_list[-1]

    def forward(self, x, mask):
        """Apply every attention head and concatenate the results.

        Args:
            x: [batch_size, seq_size, dsa_input_dim].
            mask: [batch_size, seq_size]; zero marks padding.

        Returns:
            [batch_size, dsa_num_attentions * dsa_dim].
        """
        z_list = []
        for p in self.dsa:  # dsa_num_attentions
            p_out = F.leaky_relu(p(x))
            # p_out : [batch_size, seq_size, dsa_dim]
            z_list.append(self.__self_attention(p_out, mask, r=self.dsa_r))
        return torch.cat(z_list, dim=-1)
class CharCNN(BaseModel):
    """Character-level token encoder: a char embedding followed by TextCNN per token.

    Output width per token is ``last_dim == len(char_kernel_sizes) * char_num_filters``.
    """

    def __init__(self, config):
        super().__init__(config=config)
        self.config = config
        self.device = config['opt'].device
        self.seq_size = config['n_ctx']
        self.char_n_ctx = config['char_n_ctx']
        self.char_emb_dim = config['char_emb_dim']
        self.char_padding_idx = config['char_padding_idx']
        num_filters = config['char_num_filters']
        kernel_sizes = config['char_kernel_sizes']
        # creation order (embedding, then TextCNN) fixes parameter order and seeded init
        self.embed_char = super().create_embedding_layer(
            config['char_vocab_size'],
            self.char_emb_dim,
            weights_matrix=None,
            non_trainable=False,
            padding_idx=self.char_padding_idx,
        )
        self.textcnn = TextCNN(self.char_emb_dim, num_filters, kernel_sizes)
        self.last_dim = len(kernel_sizes) * num_filters

    def forward(self, x):
        # x (char ids) : [batch_size, seq_size, char_n_ctx]
        flat_ids = x.view(-1, self.char_n_ctx)
        # char_mask : [batch_size*seq_size, char_n_ctx, 1], 1.0 where the char is not padding
        char_mask = flat_ids.ne(self.char_padding_idx).unsqueeze(-1).float()
        # embed, then fold tokens into the batch axis for the CNN
        embedded = self.embed_char(x).view(-1, self.char_n_ctx, self.char_emb_dim)
        embedded *= char_mask
        pooled = self.textcnn(embedded)
        # pooled : [batch_size*seq_size, last_dim] -> [batch_size, seq_size, last_dim]
        return pooled.view(-1, self.seq_size, pooled.shape[-1])
class GloveLSTMCRF(BaseModel):
    """GloVe + POS (+ optional char-CNN) embeddings -> BiLSTM -> linear (-> optional CRF).

    Args:
        config: dict of hyper-parameters; ``config['opt'].device`` selects the device.
        embedding_path: numpy file with the token embedding matrix.
        label_path: ``<label> <id>`` dictionary file.
        pos_path: ``<pos> <id>`` dictionary file.
        emb_non_trainable: freeze the token embedding when True.
        use_crf: add a CRF layer and return decoded predictions too.
        use_char_cnn: concatenate CharCNN features to the embeddings.
    """

    def __init__(self, config, embedding_path, label_path, pos_path, emb_non_trainable=True, use_crf=False, use_char_cnn=False):
        super().__init__(config=config)
        self.config = config
        self.device = config['opt'].device
        self.seq_size = config['n_ctx']
        pos_emb_dim = config['pos_emb_dim']
        lstm_hidden_dim = config['lstm_hidden_dim']
        lstm_num_layers = config['lstm_num_layers']
        lstm_dropout = config['lstm_dropout']
        self.use_char_cnn = use_char_cnn
        self.use_crf = use_crf
        # NOTE: submodule creation order below fixes parameter order and seeded init; keep it.
        # glove embedding layer
        weights_matrix = super().load_embedding(embedding_path)
        vocab_dim, token_emb_dim = weights_matrix.size()
        padding_idx = config['pad_token_id']
        self.embed_token = super().create_embedding_layer(vocab_dim, token_emb_dim, weights_matrix=weights_matrix, non_trainable=emb_non_trainable, padding_idx=padding_idx)
        # pos embedding layer
        self.poss = super().load_dict(pos_path)
        self.pos_vocab_size = len(self.poss)
        padding_idx = config['pad_pos_id']
        self.embed_pos = super().create_embedding_layer(self.pos_vocab_size, pos_emb_dim, weights_matrix=None, non_trainable=False, padding_idx=padding_idx)
        emb_dim = token_emb_dim + pos_emb_dim
        # char embedding layer
        if self.use_char_cnn:
            self.charcnn = CharCNN(config)
            emb_dim = token_emb_dim + pos_emb_dim + self.charcnn.last_dim
        # BiLSTM layer
        self.lstm = nn.LSTM(input_size=emb_dim,
                            hidden_size=lstm_hidden_dim,
                            num_layers=lstm_num_layers,
                            dropout=lstm_dropout,
                            bidirectional=True,
                            batch_first=True)
        self.dropout = nn.Dropout(config['dropout'])
        # projection layer
        self.labels = super().load_dict(label_path)
        self.label_size = len(self.labels)
        self.linear = nn.Linear(lstm_hidden_dim * 2, self.label_size)
        # CRF layer
        if self.use_crf:
            self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def forward(self, x):
        """Tag a batch.

        Args:
            x: x[0] token ids and x[1] POS ids, each [batch_size, seq_size];
                x[2] char ids [batch_size, seq_size, char_n_ctx], read only
                when use_char_cnn is set.

        Returns:
            logits [batch_size, seq_size, label_size] when use_crf is False,
            otherwise (logits, prediction) with prediction [batch_size, seq_size].
        """
        token_ids = x[0]
        pos_ids = x[1]
        # token id 0 counts as padding: sign(abs(0)) == 0
        mask = torch.sign(torch.abs(token_ids)).to(torch.uint8).to(self.device)
        lengths = torch.sum(mask.to(torch.long), dim=1)
        # lengths : [batch_size]
        # 1. Embedding
        token_embed_out = self.embed_token(token_ids)
        pos_embed_out = self.embed_pos(pos_ids)
        if self.use_char_cnn:
            charcnn_out = self.charcnn(x[2])
            # charcnn_out : [batch_size, seq_size, self.charcnn.last_dim]
            embed_out = torch.cat([token_embed_out, pos_embed_out, charcnn_out], dim=-1)
        else:
            embed_out = torch.cat([token_embed_out, pos_embed_out], dim=-1)
        # embed_out : [batch_size, seq_size, emb_dim]
        embed_out = self.dropout(embed_out)
        # 2. LSTM
        # pack_padded_sequence requires `lengths` to be a CPU tensor; a CUDA tensor is rejected.
        packed_embed_out = torch.nn.utils.rnn.pack_padded_sequence(embed_out, lengths.cpu(), batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed_embed_out)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True, total_length=self.seq_size)
        # lstm_out : [batch_size, seq_size, lstm_hidden_dim*2]
        lstm_out = self.dropout(lstm_out)
        # 3. Output
        logits = self.linear(lstm_out)
        # logits : [batch_size, seq_size, label_size]
        if not self.use_crf:
            return logits
        # decoded without a mask, so padded positions also receive a tag
        prediction = self.crf.decode(logits)
        prediction = torch.as_tensor(prediction, dtype=torch.long)
        # prediction : [batch_size, seq_size]
        return logits, prediction
class GloveDensenetCRF(BaseModel):
    """GloVe + POS (+ optional char-CNN) embeddings -> DenseNet -> linear (-> optional CRF).

    Args:
        config: dict of hyper-parameters; ``config['opt'].device`` selects the device
            and the ``densenet_*`` keys configure the DenseNet stack.
        embedding_path: numpy file with the token embedding matrix.
        label_path: ``<label> <id>`` dictionary file.
        pos_path: ``<pos> <id>`` dictionary file.
        emb_non_trainable: freeze the token embedding when True.
        use_crf: add a CRF layer and return decoded predictions too.
        use_char_cnn: concatenate CharCNN features to the embeddings.
    """

    def __init__(self, config, embedding_path, label_path, pos_path, emb_non_trainable=True, use_crf=False, use_char_cnn=False):
        super().__init__(config=config)
        self.config = config
        self.device = config['opt'].device
        self.seq_size = config['n_ctx']
        pos_emb_dim = config['pos_emb_dim']
        self.use_crf = use_crf
        self.use_char_cnn = use_char_cnn
        # NOTE: submodules are created in a fixed order (token emb, pos emb, char CNN,
        # DenseNet, LayerNorm, dropout, linear, CRF); this order determines parameter
        # order and seeded initialisation, so do not reorder.
        # glove embedding layer
        weights_matrix = super().load_embedding(embedding_path)
        vocab_dim, token_emb_dim = weights_matrix.size()
        padding_idx = config['pad_token_id']
        self.embed_token = super().create_embedding_layer(vocab_dim, token_emb_dim, weights_matrix=weights_matrix, non_trainable=emb_non_trainable, padding_idx=padding_idx)
        # pos embedding layer (vocabulary size = number of entries in the POS dictionary)
        self.poss = super().load_dict(pos_path)
        self.pos_vocab_size = len(self.poss)
        padding_idx = config['pad_pos_id']
        self.embed_pos = super().create_embedding_layer(self.pos_vocab_size, pos_emb_dim, weights_matrix=None, non_trainable=False, padding_idx=padding_idx)
        emb_dim = token_emb_dim + pos_emb_dim
        # char embedding layer
        if self.use_char_cnn:
            self.charcnn = CharCNN(config)
            emb_dim = token_emb_dim + pos_emb_dim + self.charcnn.last_dim
        # Densenet layer
        densenet_kernels = config['densenet_kernels']
        first_num_filters = config['densenet_first_num_filters']
        num_filters = config['densenet_num_filters']
        last_num_filters = config['densenet_last_num_filters']
        self.densenet = DenseNet(densenet_kernels, emb_dim, first_num_filters, num_filters, last_num_filters, activation=F.relu)
        # self.densenet.last_dim equals last_num_filters (set in DenseNet.__init__)
        self.layernorm_densenet = nn.LayerNorm(self.densenet.last_dim)
        self.dropout = nn.Dropout(config['dropout'])
        # projection layer
        self.labels = super().load_dict(label_path)
        self.label_size = len(self.labels)
        self.linear = nn.Linear(last_num_filters, self.label_size)
        # CRF layer
        if self.use_crf:
            self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def forward(self, x):
        """Tag a batch.

        Args:
            x: x[0] token ids and x[1] POS ids, each [batch_size, seq_size];
                x[2] char ids [batch_size, seq_size, char_n_ctx], read only
                when use_char_cnn is set.

        Returns:
            logits [batch_size, seq_size, label_size] when use_crf is False,
            otherwise (logits, prediction) with prediction [batch_size, seq_size].
        """
        # x[0, 1] : [batch_size, seq_size]
        # x[2] : [batch_size, seq_size, char_n_ctx]
        token_ids = x[0]
        pos_ids = x[1]
        # token id 0 counts as padding: sign(abs(0)) == 0
        mask = torch.sign(torch.abs(token_ids)).to(torch.uint8).to(self.device)
        # mask : [batch_size, seq_size]
        # 1. Embedding
        token_embed_out = self.embed_token(token_ids)
        # token_embed_out : [batch_size, seq_size, token_emb_dim]
        pos_embed_out = self.embed_pos(pos_ids)
        # pos_embed_out : [batch_size, seq_size, pos_emb_dim]
        if self.use_char_cnn:
            char_ids = x[2]
            # char_ids : [batch_size, seq_size, char_n_ctx]
            charcnn_out = self.charcnn(char_ids)
            # charcnn_out : [batch_size, seq_size, self.charcnn.last_dim]
            embed_out = torch.cat([token_embed_out, pos_embed_out, charcnn_out], dim=-1)
            # embed_out : [batch_size, seq_size, emb_dim]
        else:
            embed_out = torch.cat([token_embed_out, pos_embed_out], dim=-1)
            # embed_out : [batch_size, seq_size, emb_dim]
        embed_out = self.dropout(embed_out)
        # 2. DenseNet (the mask zeroes conv outputs at padded positions)
        densenet_out = self.densenet(embed_out, mask)
        # densenet_out : [batch_size, seq_size, last_num_filters]
        densenet_out = self.layernorm_densenet(densenet_out)
        densenet_out = self.dropout(densenet_out)
        # 3. Output
        logits = self.linear(densenet_out)
        # logits : [batch_size, seq_size, label_size]
        if not self.use_crf: return logits
        # decoded without a mask, so padded positions also receive a tag;
        # torch.as_tensor is given no device, so prediction is built on the CPU
        prediction = self.crf.decode(logits)
        prediction = torch.as_tensor(prediction, dtype=torch.long)
        # prediction : [batch_size, seq_size]
        return logits, prediction
class BertLSTMCRF(BaseModel):
    """BERT-family embeddings (+ optional POS) -> optional BiLSTM -> linear (-> optional CRF).

    Args:
        config: dict of hyper-parameters; ``config['opt'].device`` selects the device
            and ``config['emb_class']`` selects the model-specific call convention.
        bert_config: transformer config (reads ``hidden_size``, ``num_hidden_layers``).
        bert_model: the transformer module.
        bert_tokenizer: stored as an attribute only.
        label_path: ``<label> <id>`` dictionary file.
        pos_path: ``<pos> <id>`` dictionary file.
        use_crf: add a CRF layer and return decoded predictions too.
        use_pos: concatenate POS embeddings (read from x[3]).
        disable_lstm: feed embeddings straight to the linear layer.
        feature_based: freeze the transformer and pool its layers with DSA.
    """

    def __init__(self, config, bert_config, bert_model, bert_tokenizer, label_path, pos_path, use_crf=False, use_pos=False, disable_lstm=False, feature_based=False):
        super().__init__(config=config)
        self.config = config
        self.device = config['opt'].device
        self.seq_size = config['n_ctx']
        pos_emb_dim = config['pos_emb_dim']
        lstm_hidden_dim = config['lstm_hidden_dim']
        lstm_num_layers = config['lstm_num_layers']
        lstm_dropout = config['lstm_dropout']
        self.use_crf = use_crf
        self.use_pos = use_pos
        self.disable_lstm = disable_lstm
        # bert embedding layer
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.bert_tokenizer = bert_tokenizer
        self.bert_feature_based = feature_based
        self.bert_hidden_size = bert_config.hidden_size
        self.bert_num_layers = bert_config.num_hidden_layers
        # NOTE: submodule creation order below fixes parameter order and seeded init; keep it.
        # DSA layer for bert_feature_based. It is built unconditionally, so the set of
        # parameters does not depend on feature_based.
        dsa_num_attentions = config['dsa_num_attentions']
        dsa_input_dim = self.bert_hidden_size
        dsa_dim = config['dsa_dim']
        dsa_r = config['dsa_r']
        self.dsa = DSA(config, dsa_num_attentions, dsa_input_dim, dsa_dim, dsa_r=dsa_r)
        self.layernorm_dsa = nn.LayerNorm(self.dsa.last_dim)
        bert_emb_dim = self.bert_hidden_size
        if self.bert_feature_based:
            # DSA pooling over layers (last-layer or mean pooling would keep bert_hidden_size)
            bert_emb_dim = self.dsa.last_dim
        # pos embedding layer
        self.poss = super().load_dict(pos_path)
        self.pos_vocab_size = len(self.poss)
        padding_idx = config['pad_pos_id']
        self.embed_pos = super().create_embedding_layer(self.pos_vocab_size, pos_emb_dim, weights_matrix=None, non_trainable=False, padding_idx=padding_idx)
        # BiLSTM layer
        if self.use_pos:
            emb_dim = bert_emb_dim + pos_emb_dim
        else:
            emb_dim = bert_emb_dim
        if not self.disable_lstm:
            self.lstm = nn.LSTM(input_size=emb_dim,
                                hidden_size=lstm_hidden_dim,
                                num_layers=lstm_num_layers,
                                dropout=lstm_dropout,
                                bidirectional=True,
                                batch_first=True)
        self.dropout = nn.Dropout(config['dropout'])
        # projection layer
        self.labels = super().load_dict(label_path)
        self.label_size = len(self.labels)
        if not self.disable_lstm:
            self.linear = nn.Linear(lstm_hidden_dim * 2, self.label_size)
        else:
            self.linear = nn.Linear(emb_dim, self.label_size)
        # CRF layer
        if self.use_crf:
            self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def _compute_bert_embedding(self, x):
        """Return per-token transformer features.

        Args:
            x: x[0] input ids, x[1] attention mask, x[2] token type ids,
                each [batch_size, seq_size].

        Returns:
            [batch_size, seq_size, bert_hidden_size] when fine-tuning, or
            [batch_size, seq_size, self.dsa.last_dim] when feature-based.
        """
        if self.bert_feature_based:
            # feature-based: the transformer runs under no_grad; only DSA/LayerNorm train
            with torch.no_grad():
                if self.config['emb_class'] in ['bart', 'distilbert']:
                    bert_outputs = self.bert_model(input_ids=x[0],
                                                   attention_mask=x[1])
                    # bart model's output(output_hidden_states == True)
                    # [0] last decoder layer's output : [batch_size, seq_size, bert_hidden_size]
                    # [1] all hidden states of decoder layer's
                    # [2] last encoder layer's output : [seq_size, batch_size, bert_hidden_size]
                    # [3] all hidden states of encoder layer's
                    all_hidden_states = bert_outputs[1][0:]
                elif 'electra' in self.config['emb_class']:
                    bert_outputs = self.bert_model(input_ids=x[0],
                                                   attention_mask=x[1],
                                                   token_type_ids=x[2])
                    # electra model's output: list of each layer's hidden states
                    all_hidden_states = bert_outputs
                else:
                    # RoBERTa does not use segment ids
                    bert_outputs = self.bert_model(input_ids=x[0],
                                                   attention_mask=x[1],
                                                   token_type_ids=None if self.config['emb_class'] in ['roberta'] else x[2])
                    # bert_outputs[2]: initial embedding layer followed by every layer's hidden states
                    all_hidden_states = bert_outputs[2][0:]
            # DSA pooling across layers, independently for every token
            stack = torch.stack(all_hidden_states, dim=-2)
            # stack : [batch_size, seq_size, bert_num_layers + 1, bert_hidden_size]
            stack = stack.view(-1, self.bert_num_layers + 1, self.bert_hidden_size)
            # stack : [batch_size*seq_size, bert_num_layers + 1, bert_hidden_size]
            dsa_mask = torch.ones(stack.shape[0], stack.shape[1]).to(self.device)
            dsa_out = self.dsa(stack, dsa_mask)
            # dsa_out : [batch_size*seq_size, self.dsa.last_dim]
            dsa_out = self.layernorm_dsa(dsa_out)
            embedded = dsa_out.view(-1, self.seq_size, self.dsa.last_dim)
        else:
            # fine-tuning
            if self.config['emb_class'] in ['bart', 'distilbert']:
                bert_outputs = self.bert_model(input_ids=x[0],
                                               attention_mask=x[1])
            else:
                # RoBERTa does not use segment ids
                bert_outputs = self.bert_model(input_ids=x[0],
                                               attention_mask=x[1],
                                               token_type_ids=None if self.config['emb_class'] in ['roberta'] else x[2])
            embedded = bert_outputs[0]
            # embedded : [batch_size, seq_size, bert_hidden_size]
        return embedded

    def forward(self, x):
        """Tag a batch.

        Args:
            x: x[0] input ids, x[1] attention mask, x[2] token type ids and
                (only when use_pos) x[3] POS ids, each [batch_size, seq_size].

        Returns:
            logits [batch_size, seq_size, label_size] when use_crf is False,
            otherwise (logits, prediction) with prediction [batch_size, seq_size].
        """
        # 1. Embedding
        bert_embed_out = self._compute_bert_embedding(x)
        if self.use_pos:
            # x[3] is only required (and only read) when POS features are enabled
            pos_embed_out = self.embed_pos(x[3])
            embed_out = torch.cat([bert_embed_out, pos_embed_out], dim=-1)
        else:
            embed_out = bert_embed_out
        # embed_out : [batch_size, seq_size, emb_dim]
        embed_out = self.dropout(embed_out)
        # 2. LSTM
        if not self.disable_lstm:
            mask = x[1].to(torch.uint8).to(self.device)
            # sequence lengths come from the attention mask
            lengths = torch.sum(mask.to(torch.long), dim=1)
            # pack_padded_sequence requires `lengths` to be a CPU tensor; a CUDA tensor is rejected.
            packed_embed_out = torch.nn.utils.rnn.pack_padded_sequence(embed_out, lengths.cpu(), batch_first=True, enforce_sorted=False)
            lstm_out, _ = self.lstm(packed_embed_out)
            lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True, total_length=self.seq_size)
            # lstm_out : [batch_size, seq_size, lstm_hidden_dim*2]
            lstm_out = self.dropout(lstm_out)
        else:
            lstm_out = embed_out
            # lstm_out : [batch_size, seq_size, emb_dim]
        # 3. Output
        logits = self.linear(lstm_out)
        # logits : [batch_size, seq_size, label_size]
        if not self.use_crf:
            return logits
        # decoded without a mask, so padded positions also receive a tag
        prediction = self.crf.decode(logits)
        prediction = torch.as_tensor(prediction, dtype=torch.long)
        # prediction : [batch_size, seq_size]
        return logits, prediction
class ElmoLSTMCRF(BaseModel):
    """ELMo + GloVe + POS (+ optional char-CNN) embeddings -> BiLSTM -> linear (-> optional CRF).

    Args:
        config: dict of hyper-parameters; ``config['opt'].device`` selects the device.
        elmo_model: module whose output dict has ``'elmo_representations'``.
        embedding_path: numpy file with the GloVe embedding matrix.
        label_path: ``<label> <id>`` dictionary file.
        pos_path: ``<pos> <id>`` dictionary file.
        emb_non_trainable: freeze the GloVe embedding when True.
        use_crf: add a CRF layer and return decoded predictions too.
        use_char_cnn: concatenate CharCNN features to the embeddings.
    """

    def __init__(self, config, elmo_model, embedding_path, label_path, pos_path, emb_non_trainable=True, use_crf=False, use_char_cnn=False):
        super().__init__(config=config)
        self.config = config
        self.device = config['opt'].device
        self.seq_size = config['n_ctx']
        pos_emb_dim = config['pos_emb_dim']
        elmo_emb_dim = config['elmo_emb_dim']
        lstm_hidden_dim = config['lstm_hidden_dim']
        lstm_num_layers = config['lstm_num_layers']
        lstm_dropout = config['lstm_dropout']
        self.use_crf = use_crf
        self.use_char_cnn = use_char_cnn
        # NOTE: submodule creation order below fixes parameter order and seeded init; keep it.
        # elmo embedding
        self.elmo_model = elmo_model
        # glove embedding layer
        weights_matrix = super().load_embedding(embedding_path)
        vocab_dim, token_emb_dim = weights_matrix.size()
        padding_idx = config['pad_token_id']
        self.embed_token = super().create_embedding_layer(vocab_dim, token_emb_dim, weights_matrix=weights_matrix, non_trainable=emb_non_trainable, padding_idx=padding_idx)
        # pos embedding layer
        self.poss = super().load_dict(pos_path)
        self.pos_vocab_size = len(self.poss)
        padding_idx = config['pad_pos_id']
        self.embed_pos = super().create_embedding_layer(self.pos_vocab_size, pos_emb_dim, weights_matrix=None, non_trainable=False, padding_idx=padding_idx)
        emb_dim = elmo_emb_dim + token_emb_dim + pos_emb_dim
        # char embedding layer
        if self.use_char_cnn:
            self.charcnn = CharCNN(config)
            emb_dim = elmo_emb_dim + token_emb_dim + pos_emb_dim + self.charcnn.last_dim
        # BiLSTM layer
        self.lstm = nn.LSTM(input_size=emb_dim,
                            hidden_size=lstm_hidden_dim,
                            num_layers=lstm_num_layers,
                            dropout=lstm_dropout,
                            bidirectional=True,
                            batch_first=True)
        self.dropout = nn.Dropout(config['dropout'])
        # projection layer
        self.labels = super().load_dict(label_path)
        self.label_size = len(self.labels)
        self.linear = nn.Linear(lstm_hidden_dim * 2, self.label_size)
        # CRF layer
        if self.use_crf:
            self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def forward(self, x):
        """Tag a batch.

        Args:
            x: x[0] token ids and x[1] POS ids, each [batch_size, seq_size];
                x[2] char ids [batch_size, seq_size, max_characters_per_token].

        Returns:
            logits [batch_size, seq_size, label_size] when use_crf is False,
            otherwise (logits, prediction) with prediction [batch_size, seq_size].
        """
        token_ids = x[0]
        pos_ids = x[1]
        char_ids = x[2]
        # token id 0 counts as padding: sign(abs(0)) == 0
        mask = torch.sign(torch.abs(token_ids)).to(torch.uint8).to(self.device)
        lengths = torch.sum(mask.to(torch.long), dim=1)
        # lengths : [batch_size]
        # 1. Embedding
        elmo_embed_out = self.elmo_model(char_ids)['elmo_representations'][0]
        # elmo_embed_out : [batch_size, seq_size, elmo_emb_dim]
        token_embed_out = self.embed_token(token_ids)
        pos_embed_out = self.embed_pos(pos_ids)
        if self.use_char_cnn:
            # NOTE(review): the same x[2] tensor feeds both ELMo and CharCNN, so its last
            # dimension presumably must equal config['char_n_ctx'] as well; confirm.
            charcnn_out = self.charcnn(char_ids)
            # charcnn_out : [batch_size, seq_size, self.charcnn.last_dim]
            embed_out = torch.cat([elmo_embed_out, token_embed_out, pos_embed_out, charcnn_out], dim=-1)
        else:
            embed_out = torch.cat([elmo_embed_out, token_embed_out, pos_embed_out], dim=-1)
        # embed_out : [batch_size, seq_size, emb_dim]
        embed_out = self.dropout(embed_out)
        # 2. LSTM
        # pack_padded_sequence requires `lengths` to be a CPU tensor; a CUDA tensor is rejected.
        packed_embed_out = torch.nn.utils.rnn.pack_padded_sequence(embed_out, lengths.cpu(), batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed_embed_out)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True, total_length=self.seq_size)
        # lstm_out : [batch_size, seq_size, lstm_hidden_dim*2]
        lstm_out = self.dropout(lstm_out)
        # 3. Output
        logits = self.linear(lstm_out)
        # logits : [batch_size, seq_size, label_size]
        if not self.use_crf:
            return logits
        # decoded without a mask, so padded positions also receive a tag
        prediction = self.crf.decode(logits)
        prediction = torch.as_tensor(prediction, dtype=torch.long)
        # prediction : [batch_size, seq_size]
        return logits, prediction