predict_span.py
from argparse import ArgumentParser
from sftp import SpanPredictor
parser = ArgumentParser('predict spans')
# To run locally, you may download the model archive from the following URL and point
# the `-m` argument at the downloaded file (you may even extract the archive for the fastest loading).
parser.add_argument(
    '-m', help='model path', type=str, default='https://public.gqin.me/framenet/20210127.fn.tar.gz'
)
args = parser.parse_args()
# Specify the path to the model and the device on which the model resides.
# Here we use device -1, which indicates CPU.
predictor = SpanPredictor.from_path(
    args.m,
    cuda_device=-1,
)
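# A minimal sketch, assuming a CUDA-capable GPU is available: pass its index
# (e.g. 0) as `cuda_device` instead of -1 to place the model on that GPU.
# predictor = SpanPredictor.from_path(args.m, cuda_device=0)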
# The input sentence can be a plain string. It will be tokenized by SpacyTokenizer,
# and the tokens will be returned along with the predictions.
input1 = "Bob saw Alice eating an apple."
print("Example 1 with input:", input1)
output1 = predictor.predict_sentence(input1)
# The `tree` function can print out a human-readable parse tree.
output1.span.tree(output1.sentence)
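# For reference, the tokens produced by the tokenizer are available on the output
# (this mirrors the `sentence` field already used in the `tree` call above).
print("Tokens:", output1.sentence)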
# The input sentence may also be pre-tokenized (a list of strings). In that case,
# the given tokenization is respected and the output is based on the given tokens.
input2 = ["Bob", "saw", "Alice", "eating", "an", "apple", "."]
print('-'*20+"\nExample 2 with input:", input2)
output2 = predictor.predict_sentence(input2)
output2.span.tree(output2.sentence)
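# The pre-tokenized input should come back unchanged, since the given tokenization
# is respected (a sketch assuming `sentence` holds the same token strings).
print("Tokens preserved:", list(output2.sentence) == input2)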
# For efficiency, you can pass all the sentences at once.
# Note: the predictor does its own batching.
# Instead of specifying a batch size, specify `max_tokens`, the maximum number
# of tokens that may be put into one batch.
# The predictor dynamically batches the input sentences efficiently,
# and the outputs are returned in the same order as the inputs.
output3 = predictor.predict_batch_sentences([input1, input2], max_tokens=512, progress=True)
print('-'*20+"\nExample 3 with both inputs:")
for i in range(2):
    output3[i].span.tree(output3[i].sentence)
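# A minimal sketch of scaling this up: read one sentence per line from a file
# (the path `sentences.txt` is hypothetical) and let the predictor batch them all.
# with open('sentences.txt') as fp:
#     sentences = [line.strip() for line in fp]
# outputs = predictor.predict_batch_sentences(sentences, max_tokens=512, progress=True)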
# For SRL, we can limit the decoding depth if we only need the event predictions
# (saving about 13% of the time), and we can cap the number of decoded spans for a further speedup.
predictor.economize(max_decoding_spans=20, max_recursion_depth=1)
output4 = predictor.predict_batch_sentences([input2], max_tokens=512)
print('-'*20+"\nExample 4 with input:", input2)
output4[0].span.tree(output4[0].sentence)
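# To relax the constraints afterwards, `economize` can be called again with larger
# limits (the values below are illustrative, not the library's defaults).
# predictor.economize(max_decoding_spans=128, max_recursion_depth=4)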