model_class.py
import torch
import torch.nn as nn
from transformers import BertModel


# switch_model
class SwitchModel(nn.Module):
    def __init__(self, num_labels):
        super(SwitchModel, self).__init__()
        self.encode = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        self.drop_out = nn.Dropout(0.3)
        self.l1 = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_masks):
        outputs = self.encode(input_ids, attention_masks)
        # mean-pool the second-to-last BERT hidden layer over the token dimension
        input1 = torch.mean(outputs[2][-2], dim=1)
        input1 = self.drop_out(input1)
        output1 = self.l1(input1)
        return output1
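

# A minimal usage sketch for SwitchModel, not part of the original file: the
# tokenizer, example sentence, max_length and num_labels below are assumptions
# chosen only to illustrate the expected input/output shapes.
def _demo_switch_model():
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    enc = tokenizer(["book a table for two please"], padding='max_length',
                    max_length=28, truncation=True, return_tensors='pt')
    model = SwitchModel(num_labels=2)
    # logits over the assumed two switch labels, shape (1, 2)
    return model(enc['input_ids'], enc['attention_mask'])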


# domain_model
class DomainModel(nn.Module):
    def __init__(self, domain_matrix, num_labels):
        super(DomainModel, self).__init__()
        self.encode = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        # frozen lookup table built from the pretrained 300-d domain embeddings
        self.embedding_domain = nn.Embedding.from_pretrained(torch.FloatTensor(domain_matrix))
        self.drop_out = nn.Dropout(0.3)
        self.gelu = nn.GELU()
        self.l1 = nn.Linear(300, 768)
        self.l2 = nn.Linear(768 * 2, num_labels)
        self.smax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_masks, domain_ids):
        outputs = self.encode(input_ids, attention_masks)
        with torch.no_grad():
            input2 = self.embedding_domain(domain_ids)
        # token-level representations from the second-to-last BERT layer
        input1 = outputs[2][-2]
        # project the 300-d domain embedding into BERT's 768-d space
        input2 = self.l1(input2)
        input2 = self.gelu(input2)
        # attention scores of the domain vector over the tokens
        # (the 28.0 scaling constant is close to sqrt(768))
        input3 = torch.unsqueeze(input2, -1)
        a = torch.matmul(input1, input3) / 28.0
        a = self.smax(torch.squeeze(a, -1))
        a = torch.unsqueeze(a, -1)
        # attention-weighted sum of the token representations
        input1 = input1.permute(0, 2, 1)
        input1 = torch.matmul(input1, a)
        input1 = torch.squeeze(input1, -1)
        # classify the pooled tokens concatenated with the domain vector
        output = torch.cat((input1, input2), 1)
        output = self.drop_out(output)
        output = self.l2(output)
        return output
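

# A minimal usage sketch for DomainModel, not part of the original file: the
# random 5 x 300 domain_matrix, the tokenizer and the label count are
# placeholder assumptions standing in for the real pretrained domain embeddings.
def _demo_domain_model():
    import numpy as np
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    enc = tokenizer(["find me a cheap hotel downtown"], padding='max_length',
                    max_length=28, truncation=True, return_tensors='pt')
    domain_matrix = np.random.rand(5, 300).astype('float32')  # 5 fake domains
    model = DomainModel(domain_matrix, num_labels=3)
    domain_ids = torch.tensor([1])  # one domain index per example in the batch
    # logits over the assumed three labels, shape (1, 3)
    return model(enc['input_ids'], enc['attention_mask'], domain_ids)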


# slot action model
class SlotActionModel(nn.Module):
    def __init__(self, weights_matrix, domain_matrix, num_labels):
        super(SlotActionModel, self).__init__()
        self.encode = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        # frozen lookup tables built from the pretrained 300-d slot and domain embeddings
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix))
        self.embedding_domain = nn.Embedding.from_pretrained(torch.FloatTensor(domain_matrix))
        self.drop_out = nn.Dropout(0.3)
        self.gelu = nn.GELU()
        self.l1 = nn.Linear(300 * 2, 768)
        self.l2 = nn.Linear(768 * 2, num_labels)
        self.smax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_masks, slot_ids, domain_ids):
        outputs = self.encode(input_ids, attention_masks)
        with torch.no_grad():
            slot_embeddings = self.embedding(slot_ids)
            domain_embeddings = self.embedding_domain(domain_ids)
        # concatenate the slot and domain embeddings into one 600-d query vector
        input2 = torch.cat((slot_embeddings, domain_embeddings), 1)
        # token-level representations from the second-to-last BERT layer
        input1 = outputs[2][-2]
        # project the slot/domain query into BERT's 768-d space
        input2 = self.l1(input2)
        input2 = self.gelu(input2)
        # attention scores of the query vector over the tokens
        # (the 28.0 scaling constant is close to sqrt(768))
        input3 = torch.unsqueeze(input2, -1)
        a = torch.matmul(input1, input3) / 28.0
        a = self.smax(torch.squeeze(a, -1))
        a = torch.unsqueeze(a, -1)
        # attention-weighted sum of the token representations
        input1 = input1.permute(0, 2, 1)
        input1 = torch.matmul(input1, a)
        input1 = torch.squeeze(input1, -1)
        # classify the pooled tokens concatenated with the slot/domain query
        output = torch.cat((input1, input2), 1)
        output = self.drop_out(output)
        output = self.l2(output)
        return output
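

# A minimal usage sketch for SlotActionModel, not part of the original file:
# the random slot/domain matrices, tokenizer and label count are placeholder
# assumptions standing in for the real pretrained slot and domain embeddings.
def _demo_slot_action_model():
    import numpy as np
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    enc = tokenizer(["book it for 7 pm"], padding='max_length',
                    max_length=28, truncation=True, return_tensors='pt')
    weights_matrix = np.random.rand(10, 300).astype('float32')  # 10 fake slots
    domain_matrix = np.random.rand(5, 300).astype('float32')    # 5 fake domains
    model = SlotActionModel(weights_matrix, domain_matrix, num_labels=4)
    slot_ids = torch.tensor([0])
    domain_ids = torch.tensor([1])
    # logits over the assumed four slot-action labels, shape (1, 4)
    return model(enc['input_ids'], enc['attention_mask'], slot_ids, domain_ids)


if __name__ == "__main__":
    print(_demo_switch_model().shape)       # torch.Size([1, 2])
    print(_demo_domain_model().shape)       # torch.Size([1, 3])
    print(_demo_slot_action_model().shape)  # torch.Size([1, 4])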