forked from airKlizz/TextSegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuse.py
33 lines (29 loc) · 1.2 KB
/
use.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
from model.segmenter import Segmenter
import numpy as np
from nltk.tokenize import sent_tokenize
class TextSegmenter:
    """Split raw text into passages using a trained ``Segmenter`` model.

    The text is cut into sentences with NLTK's ``sent_tokenize``; the model
    scores every sentence, and a sentence whose highest-scoring class is ``1``
    opens a new passage.
    """

    def __init__(self, model_weights, bidirectional, num_classification_layers,
                 max_sentences=64):
        """Create the model and restore its weights.

        Args:
            model_weights: Value handed to ``self.model.load_weights``
                (presumably a checkpoint path -- confirm against Segmenter).
            bidirectional: Forwarded unchanged to ``Segmenter``.
            num_classification_layers: Forwarded unchanged to ``Segmenter``.
            max_sentences: Forwarded unchanged to ``Segmenter`` as its first
                argument.
        """
        self.model = Segmenter(max_sentences, bidirectional, num_classification_layers)
        # One throw-away forward pass before loading weights. Presumably this
        # builds the model's variables so load_weights has something to restore
        # into -- confirm against Segmenter.
        _ = self.model(
            [['Sentence 0', 'sentence 1', 'sentence 3'],
             ['Sentence 0', 'sentence 1', 'sentence 3']],
            prepare_inputs=True,
        )
        self.model.load_weights(model_weights)

    def segment(self, text):
        """Return the passages of ``text`` as a list of strings.

        Args:
            text: The document to segment.

        Returns:
            A list of passages, each made of consecutive sentences joined by a
            single space. Empty, whitespace-only or ``None`` input yields ``[]``
            without running the model.
        """
        # Guard first: with no sentences there is nothing to score, and the
        # grouping step would have no passage to emit.
        if not text or not text.strip():
            return []

        sentences = sent_tokenize(text)
        if not sentences:
            return []

        # Single-document batch; keep only the rows that belong to real
        # sentences (presumably the model pads its output -- TODO confirm).
        scores = self.model([sentences], prepare_inputs=True).numpy()[0][:len(sentences)]
        labels = np.argmax(scores, axis=-1).tolist()
        return self._merge_sentences(sentences, labels)

    @staticmethod
    def _merge_sentences(sentences, labels):
        """Group sentences into passages; label ``1`` starts a new passage.

        The first sentence always opens the first passage, whatever its label.
        """
        # NOTE(review): zip stops at the shorter input, so if the model yields
        # fewer labels than sentences the trailing sentences are dropped --
        # confirm this is the intended handling of over-long documents.
        passages = []
        current = []
        for index, (sentence, label) in enumerate(zip(sentences, labels)):
            if index > 0 and label == 1:
                passages.append(' '.join(current))
                current = []
            current.append(sentence)
        if current:
            passages.append(' '.join(current))
        return passages

    @staticmethod
    def print_passages(passages):
        """Print each passage followed by a blank line."""
        for p in passages:
            print(p + '\n')