# precompute_ref_named_entities_and_create.py
import os
from os.path import exists, join
import json
from time import time
from datetime import timedelta
import multiprocessing as mp
from cytoolz import concat, curry, compose
import argparse
import re
import spacy
import neuralcoref
nlp = spacy.load('en')
neuralcoref.add_to_pipe(nlp)
DATE_TIME_NUMERICAL_ENTITIES_TYPES = ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
MAX_OUT_LEN = 100
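
# The expected input format (an assumption based on how this script reads each
# example) is one JSON file per example, named 0.json, 1.json, ..., containing at
# least the keys 'article' (a list of document sentences) and 'abstract' (a list of
# summary sentences). The script writes each example back out with the added keys
# 'reference_entity_list_non_numerical', 'reference_entity_start_end_list',
# 'reference_entity_list_non_numerical_str', and 'reference_coref_clusters'.
# Note: neuralcoref requires spaCy 2.x, which provides the 'en' shortcut model
# loaded above.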


def _count_data(path):
    """Count the number of data files (0.json, 1.json, ...) in the given path."""
    matcher = re.compile(r'[0-9]+\.json')
    match = lambda name: bool(matcher.match(name))
    names = os.listdir(path)
    n_data = len(list(filter(match, names)))
    return n_data


def extract_entities(sent_list):
    """Extract non-numerical named entities from a list of sentences.

    Returns the flat list of entity strings and, for each input sentence,
    the number of such entities it contains.
    """
    entity_list = []
    num_entities_for_each_sent = [0] * len(sent_list)
    for sent_i, sent in enumerate(sent_list):
        for sent_spacy in nlp(sent).sents:
            for ent in sent_spacy.ents:
                if ent.label_ not in DATE_TIME_NUMERICAL_ENTITIES_TYPES:
                    num_entities_for_each_sent[sent_i] += 1
                    entity_list.append(ent.text)
    return entity_list, num_entities_for_each_sent
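
# Illustrative sketch (not part of the original script): with the standard English
# model, extract_entities(["Barack Obama visited Paris on Tuesday."]) would be
# expected to return something like (['Barack Obama', 'Paris'], [2]), since
# 'Tuesday' is a DATE entity and is excluded. The exact output depends on the model.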


def check_present_named_entities(doc_word_list, named_entity_words_list):
    """For each entity (given as a word list), find its first occurrence in the document.

    Returns a list of (start, end) word offsets, with end exclusive, or (-1, -1)
    if the entity does not appear in doc_word_list.
    """
    entity_start_end_list = []
    for entity_words in named_entity_words_list:  # for each named entity
        # check if it appears in the document
        match = False
        for doc_start_idx in range(len(doc_word_list) - len(entity_words) + 1):
            match = True
            for entity_word_idx, entity_word in enumerate(entity_words):
                doc_word = doc_word_list[doc_start_idx + entity_word_idx]
                if doc_word != entity_word:
                    match = False
                    break
            if match:
                break
        if match:
            entity_start_end_list.append((doc_start_idx, doc_start_idx + len(entity_words)))
        else:
            entity_start_end_list.append((-1, -1))
    return entity_start_end_list
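
# Illustrative sketch (hypothetical inputs, not from the original script):
#   check_present_named_entities(
#       ['barack', 'obama', 'visited', 'paris'],
#       [['barack', 'obama'], ['london']])
# returns [(0, 2), (-1, -1)]: the first entity spans words 0-2 (end exclusive),
# and the second does not occur in the document.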


@curry
def process(in_data_dir, out_data_dir, i):
    try:
        with open(join(in_data_dir, '{}.json'.format(i))) as f:
            js = json.loads(f.read())
        doc_sent_list = js['article']
        summary_sent_list = js['abstract']
        if doc_sent_list and summary_sent_list:
            doc_word_list_lower = ' '.join(doc_sent_list).lower().split(' ')
            # truncate the summary to the first MAX_OUT_LEN tokens
            summary_word_list = ' '.join(summary_sent_list).split(' ')
            summary_str = ' '.join(summary_word_list[:MAX_OUT_LEN])
            # extract coreference clusters and named entities from the summary
            summary_spacy = nlp(summary_str)
            coref_clusters = {}
            reference_entity_list_non_numerical = []
            processed_entities = []
            for ent in summary_spacy.ents:
                # keep each non-numerical entity once
                if ent.label_ not in DATE_TIME_NUMERICAL_ENTITIES_TYPES and ent.text not in processed_entities:
                    processed_entities.append(ent.text)
                    coref = ent._.coref_cluster
                    if coref is not None:
                        coref_clusters[ent.text] = [mention.text for mention in coref.mentions]
                    else:
                        coref_clusters[ent.text] = [ent.text]
                    reference_entity_list_non_numerical.append(ent.text)
            # locate each reference entity in the (lower-cased) document
            reference_named_entity_words_list_lower = [entity_str.lower().split(' ') for entity_str in reference_entity_list_non_numerical]
            entity_start_end_list = check_present_named_entities(doc_word_list_lower, reference_named_entity_words_list_lower)
            # drop entities that are absent or not within the first 400 document words,
            # and keep the precomputed position ids of the rest
            filtered_reference_entity_list_non_numerical = []
            filtered_entity_start_end_list = []
            for entity_str, (entity_start, entity_end) in zip(reference_entity_list_non_numerical, entity_start_end_list):
                if 0 <= entity_end < 400:
                    filtered_reference_entity_list_non_numerical.append(entity_str)
                    filtered_entity_start_end_list.append((entity_start, entity_end))
                else:
                    coref_clusters.pop(entity_str, None)
            # remove summary sentences that contain no reference entity or mention of one
            if len(filtered_reference_entity_list_non_numerical) > 0:
                summary_sent_list_filtered = []
                for summary_sent in summary_sent_list:
                    match_flag = False
                    for entity, mentions in coref_clusters.items():
                        for mention in mentions:
                            if mention in summary_sent:
                                match_flag = True
                                break
                        if match_flag:
                            break
                    if match_flag:
                        summary_sent_list_filtered.append(summary_sent)
            else:
                summary_sent_list_filtered = summary_sent_list
            js['reference_entity_list_non_numerical'] = filtered_reference_entity_list_non_numerical
            js['reference_entity_start_end_list'] = filtered_entity_start_end_list
            if len(filtered_reference_entity_list_non_numerical) > 0:
                js['reference_entity_list_non_numerical_str'] = ' <ent> '.join(filtered_reference_entity_list_non_numerical) + ' <ent_end>'
            else:
                js['reference_entity_list_non_numerical_str'] = ""
            js['reference_coref_clusters'] = coref_clusters
            js['abstract'] = summary_sent_list_filtered
            # remove some unused keys
            js.pop("extractive_fragment_density", None)
            js.pop("extractive_fragment_coverage", None)
            js.pop("similar_source_indices_lebanoff", None)
            js.pop("avg_fusion_ratio", None)
            js.pop("unique_two_gram_novelty", None)
            js.pop("two_gram_novelty", None)
            with open(join(out_data_dir, '{}.json'.format(i)), 'w') as f:
                json.dump(js, f, indent=4)
        else:
            # empty article or abstract: write the example through with empty entity fields
            js['reference_entity_list_non_numerical'] = []
            js['reference_entity_start_end_list'] = []
            js['reference_entity_list_non_numerical_str'] = ""
            with open(join(out_data_dir, '{}.json'.format(i)), 'w') as f:
                json.dump(js, f, indent=4)
    except Exception:
        # skip examples that fail to parse or process
        pass
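
# Illustrative sketch (hypothetical paths): because process is curried, it can be
# partially applied with the input and output directories and then mapped over
# example indices, e.g.
#   worker = process('finished_files/train', 'finished_files_entity/train')
#   worker(0)   # reads .../train/0.json and writes the augmented copy
# This is how label_mp below feeds it to pool.imap_unordered.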


def label_mp(in_data, out_data, split):
    """Process one data split with multiprocessing."""
    start = time()
    print('start processing {} split...'.format(split))
    in_data_dir = join(in_data, split)
    out_data_dir = join(out_data, split)
    os.makedirs(out_data_dir)
    n_data = _count_data(in_data_dir)
    with mp.Pool() as pool:
        list(pool.imap_unordered(process(in_data_dir, out_data_dir),
                                 list(range(n_data)), chunksize=1024))
    print('finished in {}'.format(timedelta(seconds=time() - start)))


def label(in_data, out_data, split):
    """Process one data split sequentially (single process); useful for debugging."""
    in_data_dir = join(in_data, split)
    out_data_dir = join(out_data, split)
    os.makedirs(out_data_dir)
    n_data = _count_data(in_data_dir)
    for i in range(n_data):
        process(in_data_dir, out_data_dir, i)


def main(in_data, out_data, split):
    if split == 'all':
        for split in ['val', 'train', 'test']:
            label_mp(in_data, out_data, split)
    else:
        label_mp(in_data, out_data, split)
        # use label(in_data, out_data, split) instead for single-process debugging


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=('Precompute the reference (summary) named entities and their '
                     'coreference clusters for each example in a data split.')
    )
    parser.add_argument('-in_data', type=str, action='store',
                        help='The directory of the input data.')
    parser.add_argument('-out_data', type=str, action='store',
                        help='The directory for the output data.')
    parser.add_argument('-split', type=str, action='store', default='all',
                        help='The split to process. "all" processes the train, val, and test splits.')
    args = parser.parse_args()
    main(args.in_data, args.out_data, args.split)
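
# Example invocation (paths are hypothetical):
#   python precompute_ref_named_entities_and_create.py \
#       -in_data ./finished_files -out_data ./finished_files_entity -split all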