Pooling_custom.py
import json
import os
from typing import Dict

import torch
from torch import Tensor, nn

class Pooling(nn.Module):
    """Performs pooling (max or mean) on the token embeddings.

    Pooling turns a variable-length sentence into a fixed-size sentence
    embedding. This layer can also use the CLS token if it is returned by the
    underlying word embedding model, and multiple pooling modes can be
    concatenated together.

    :param word_embedding_dimension: Dimension of the word embeddings
    :param pooling_mode_cls_token: Use the first token (CLS token) as the text representation
    :param pooling_mode_max_tokens: Use the max in each dimension over all tokens
    :param pooling_mode_mean_tokens: Perform mean pooling
    :param pooling_mode_mean_sqrt_len_tokens: Perform mean pooling, but divide by sqrt(input_length)
    :param pooling_mode_mean_mark_tokens: Perform mean pooling over the tokens flagged by 'mark_token_ids'
    """
    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_mean_mark_tokens: bool = False,
    ):
        super().__init__()
        self.config_keys = [
            'word_embedding_dimension', 'pooling_mode_cls_token',
            'pooling_mode_mean_tokens', 'pooling_mode_max_tokens',
            'pooling_mode_mean_mark_tokens',
            'pooling_mode_mean_sqrt_len_tokens'
        ]
        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_mean_mark_tokens = pooling_mode_mean_mark_tokens

        # Each enabled pooling mode contributes one block of size
        # word_embedding_dimension to the concatenated output vector.
        pooling_mode_multiplier = sum([
            pooling_mode_cls_token, pooling_mode_max_tokens,
            pooling_mode_mean_tokens, pooling_mode_mean_sqrt_len_tokens,
            pooling_mode_mean_mark_tokens
        ])
        self.pooling_output_dimension = (pooling_mode_multiplier *
                                         word_embedding_dimension)
    def forward(self, features: Dict[str, Tensor]):
        token_embeddings = features['token_embeddings']
        cls_token = features['cls_token_embeddings']
        attention_mask = features['attention_mask']
        mark_token_ids = features['mark_token_ids']

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(
                token_embeddings.size()).float()
            # Set padding tokens to a large negative value so they never win the max
            token_embeddings[input_mask_expanded == 0] = -1e9
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            token_ids = features['input_ids']
            # First occurrence of token id 3 in each sequence; this layer treats
            # it as the end of the meaningful tokens for mean pooling.
            meaningful_token_ids = [
                ids.index(3) for ids in token_ids.cpu().numpy().tolist()
            ]
            # Mask out everything from that position onwards, per sequence,
            # and also exclude the first (CLS) token from the mean.
            for row, end_position in enumerate(meaningful_token_ids):
                attention_mask[row, end_position:] = 0
            attention_mask[:, :1] = 0
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(
                token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by a WordWeights layer), the feature
            # 'token_weights_sum' will be present
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(
                    sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_mean_mark_tokens:
            # Mean pooling restricted to the tokens flagged in 'mark_token_ids'
            marked_tokens = mark_token_ids.unsqueeze(-1).expand(
                token_embeddings.size()).float()
            sum_mask = marked_tokens.sum(1)
            filtered_features = torch.sum(token_embeddings * marked_tokens, 1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(filtered_features / sum_mask)

        output_vector = torch.cat(output_vectors, 1)
        features.update({'sentence_embedding': output_vector})
        return features
    def get_sentence_embedding_dimension(self):
        return self.pooling_output_dimension

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'config.json')) as fIn:
            config = json.load(fIn)
        return Pooling(**config)
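

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): builds a dummy
    # feature dict of the shape forward() expects and runs the default mean
    # pooling. All tensor values below are synthetic placeholders, and the
    # example assumes token id 3 marks the end of the meaningful tokens, as
    # forward() does.
    batch_size, seq_len, dim = 2, 6, 8
    token_embeddings = torch.randn(batch_size, seq_len, dim)
    features = {
        'token_embeddings': token_embeddings,
        'cls_token_embeddings': token_embeddings[:, 0],
        'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
        # Each sequence contains token id 3 so the mean-pooling modes can
        # locate the end of the meaningful tokens.
        'input_ids': torch.tensor([[0, 5, 7, 3, 1, 1],
                                   [0, 9, 4, 6, 3, 1]]),
        # Flags the tokens used by pooling_mode_mean_mark_tokens (unused here).
        'mark_token_ids': torch.tensor([[0, 1, 1, 0, 0, 0],
                                        [0, 1, 1, 1, 0, 0]]),
    }

    pooling = Pooling(word_embedding_dimension=dim, pooling_mode_mean_tokens=True)
    out = pooling(features)
    print(out['sentence_embedding'].shape)  # -> torch.Size([2, 8])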