# ner_demo.py
import itertools
from typing import Dict, Iterator, List, Optional

import torch
import torch.nn as nn
import torch.optim as optim

from allennlp.common.file_utils import cached_path
from allennlp.common.params import Params
from allennlp.data import Instance
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, SequenceLabelField, TextField
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, Seq2SeqEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder, TextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import SpanBasedF1Measure
from allennlp.training.trainer import Trainer
@DatasetReader.register('conll_03_reader')
class CoNLL03DatasetReader(DatasetReader):
    """Reads CoNLL-2003 files: one token per line, blank lines between sentences."""

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self.token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    def _read(self, file_path: str) -> Iterator[Instance]:
        # Blank lines act as sentence dividers.
        is_divider = lambda line: line.strip() == ''
        with open(file_path, 'r') as conll_file:
            for divider, lines in itertools.groupby(conll_file, is_divider):
                if not divider:
                    # Each line is "token POS chunk NER"; transpose rows into columns.
                    fields = [line.strip().split() for line in lines]
                    tokens, _, _, ner_tags = zip(*fields)
                    yield self.text_to_instance(tokens, ner_tags)

    def text_to_instance(self,
                         words: List[str],
                         ner_tags: List[str]) -> Instance:
        fields: Dict[str, Field] = {}
        tokens = TextField([Token(w) for w in words], self.token_indexers)
        fields['tokens'] = tokens
        # The label field is tied to the token sequence so lengths stay aligned.
        fields['label'] = SequenceLabelField(labels=ner_tags, sequence_field=tokens)
        return Instance(fields)
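# For reference, a CoNLL-2003 sentence block looks like this (token, POS,
# chunk tag, NER tag), with blank lines separating sentences:
#
#   U.N.      NNP  I-NP  I-ORG
#   official  NN   I-NP  O
#
# Each block becomes one Instance with a 'tokens' TextField and a 'label'
# SequenceLabelField.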
@Model.register('ner_lstm')
class NerLSTM(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)
        self._embedder = embedder
        self._encoder = encoder
        # Project each encoder output onto the label space.
        self._classifier = nn.Linear(in_features=encoder.get_output_dim(),
                                     out_features=vocab.get_vocab_size('labels'))
        self.f1 = SpanBasedF1Measure(vocab, 'labels')

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Mask out padding positions before encoding and loss computation.
        mask = get_text_field_mask(tokens)
        embedded = self._embedder(tokens)
        encoded = self._encoder(embedded, mask)
        classified = self._classifier(encoded)

        output: Dict[str, torch.Tensor] = {'logits': classified}
        if label is not None:
            # Update the metric and compute the loss only when gold tags exist,
            # so the model can also be run at prediction time.
            self.f1(classified, label, mask)
            output['loss'] = sequence_cross_entropy_with_logits(classified, label, mask)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return self.f1.get_metric(reset)
# reader
reader = CoNLL03DatasetReader()
train_dataset = reader.read(cached_path(
    'https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train'))
validation_dataset = reader.read(cached_path(
    'https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa'))
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)
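# Quick sanity check on the built vocabulary (actual sizes depend on the data):
print('token vocab:', vocab.get_vocab_size('tokens'),
      '| label vocab:', vocab.get_vocab_size('labels'))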
# embedding
EMBEDDING_DIM = 50
# Use pretrained GloVe vectors. With the 0.x AllenNLP API used in this script,
# pretrained weights are loaded through Embedding.from_params; passing
# pretrained_file to the constructor directly only records the path without
# reading the vectors.
token_embedding = Embedding.from_params(
    vocab=vocab,
    params=Params({
        'pretrained_file': '(http://nlp.stanford.edu/data/glove.6B.zip)#glove.6B.50d.txt',
        'embedding_dim': EMBEDDING_DIM,
        'trainable': False,
    }))
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})
# lstm
HIDDEN_DIM = 25
# batch_first=True is required: AllenNLP tensors are (batch, seq_len, dim).
lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, bidirectional=True, batch_first=True))
model = NerLSTM(vocab, word_embeddings, lstm)
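# Shape sketch for one forward pass (comments only; batch/seq sizes illustrative):
#   embedded:   (batch, seq_len, EMBEDDING_DIM)   -> (B, T, 50)
#   encoded:    (batch, seq_len, 2 * HIDDEN_DIM)  -> (B, T, 50), since bidirectional
#   classified: (batch, seq_len, num_labels)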
# cuda device: GPU 0 if available; -1 tells the Trainer to stay on CPU
if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
# optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# iterator: buckets batches by sentence length to reduce padding
iterator = BucketIterator(batch_size=10, sorting_keys=[('tokens', 'num_tokens')])
iterator.index_with(vocab)
# trainer
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    patience=3,                  # early stopping on the validation metric
    num_epochs=10,
    cuda_device=cuda_device,
    validation_metric='-loss',   # the '-' prefix means lower is better
    grad_clipping=5.0,
)
# train
trainer.train()
# save model
with open('/tmp/model.th', 'wb') as f:
    torch.save(model.state_dict(), f)
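# The state dict alone cannot be reloaded without the matching vocabulary;
# saving it as well (an addition to the original script) keeps the token and
# label id mappings reproducible.
vocab.save_to_files('/tmp/vocabulary')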
# predictor
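# A minimal prediction sketch (an assumed usage, not part of the original
# script): reuse text_to_instance with dummy 'O' tags, since only the tokens
# matter at inference time, and decode the argmax of the returned logits.
model.eval()
words = ['John', 'lives', 'in', 'New', 'York']
instance = reader.text_to_instance(words, ['O'] * len(words))
logits = model.forward_on_instance(instance)['logits']
predicted_tags = [vocab.get_token_from_index(int(i), namespace='labels')
                  for i in logits.argmax(axis=-1)]
print(list(zip(words, predicted_tags)))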
# comparison
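# A rough comparison sketch (also assumed): predicted vs. gold tags for the
# first validation sentence.
instance = validation_dataset[0]
pred_ids = model.forward_on_instance(instance)['logits'].argmax(axis=-1)
predicted = [vocab.get_token_from_index(int(i), namespace='labels') for i in pred_ids]
gold = instance.fields['label'].labels
for token, p, g in zip(instance.fields['tokens'].tokens, predicted, gold):
    print(f'{token.text:15} pred={p:10} gold={g}')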