-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_complete.py
104 lines (104 loc) · 3.97 KB
/
inference_complete.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import re
os.environ["CUDA_VISIBLE_DEVICES"]="3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import config
import random
import argparse
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from model import BertSegPos
from evaluate_complete import evaluate
from data_process import load_data
from test_file import generate_file
from transformers import AutoModel
from torch.utils.data import DataLoader
from data_loader import AnChinaDataset
# from metrics import f1_score, bad_case, output_write, output2res
from transformers.optimization import get_cosine_schedule_with_warmup, AdamW
# Corpus whose sentences are re-segmented / re-tagged by the model below.
data_dir = './tgt.shuf.seg_pos'
# data_dir = '/EvaHan_testb_raw.txt'
# NOTE(review): load_data presumably returns parallel per-sentence sequences
# plus n-gram side information (gram_*); confirm against data_process.load_data.
sentences, segs, poss, segpos, flag, gram_list, positions, gram_maxlen, gram2id = load_data(data_dir)


def _to_object_array(items):
    """Return *items* as a 1-D object ndarray holding one element per item.

    ``np.array(items, dtype=object)`` builds a 2-D array when every item is a
    sequence of the same length, turning each row into an ndarray instead of
    the original object. Filling a pre-allocated 1-D array element by element
    always yields exactly ``len(items)`` elements, each left unchanged.
    """
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item
    return out


length = len(sentences)
part_length = length
# Boolean row mask over [0, part_length]. Because part_length == length, every
# row is kept; change part_length to restrict inference to a prefix.
index_train = np.array([0 <= j <= part_length for j in range(length)], dtype=bool)

sentences = _to_object_array(sentences)[index_train]
segs = _to_object_array(segs)[index_train]
poss = _to_object_array(poss)[index_train]
segpos = _to_object_array(segpos)[index_train]
flag = _to_object_array(flag)[index_train]
if config.use_attention:
    gram_list = _to_object_array(gram_list)[index_train]
    positions = _to_object_array(positions)[index_train]
    gram_maxlen = _to_object_array(gram_maxlen)[index_train]
else:
    # The n-gram inputs are only passed to the dataset when attention is on.
    gram_list = None
    positions = None
    gram_maxlen = None
print("load data success!")
# (stale) generate_file(test_sentences,test_seg,test_pos,'temp.txt',flag[train_size+test_size:])

# Seed every RNG before the dataset and loader are built, so that anything
# they draw (e.g. the DataLoader's per-iteration worker base seed) is
# reproducible across runs.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# gram_list / positions / gram_maxlen are None when config.use_attention is off.
test_dataset = AnChinaDataset(sentences, segs, poss, segpos, gram_list, positions, gram_maxlen, gram2id)
device = torch.device('cuda:0')
# shuffle is left at its default (False) so predictions stay in the same order
# as `sentences` / `flag`. os.cpu_count() returns None when the CPU count is
# undeterminable; fall back to loading in the main process in that case.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    collate_fn=test_dataset.collate_fn,
    num_workers=os.cpu_count() or 0,
    pin_memory=True,
)
# Build the model and restore the trained weights onto the inference device.
model = BertSegPos(config, None)
model.to(device)
model.load_state_dict(torch.load('abalation/model_base.pth', map_location=device))
# nn.Module instances start in training mode; switch to eval mode so dropout
# (and any normalisation layers) behave as at inference time. evaluate() may
# already do this; calling it here keeps the script correct either way.
model.eval()
print("load model success!")
pred_segs, pred_poss = evaluate(test_loader, model, 'test')
print("predict finish!")
# Pass the original sentences plus predicted segmentations / POS tags to
# generate_file; presumably it writes them to the given path (see test_file).
generate_file(sentences, pred_segs, pred_poss, 'abalation/tgt.shuf.seg_pos-base', flag)
# all_data=open('tgt.shuf.seg_pos-infer-repair', 'r', encoding='utf-8').readlines()
# with open("tgt.shuf.seg_pos-reweight-long", 'a', encoding='utf-8-sig') as f:
# for data in all_data:
# if data == '\n':
# continue
# data = data[:-1]
# data += '。/w'
# f.write(data)
# f.write('\n')
# # split_chars = [',',':',',',':','。']
# # length_data = len(data)
# # pre_loc = 0
# # for i in range(length_data):
# # if data[i] in split_chars:
# # f.write(data[pre_loc:i+3])
# # f.write('\n')
# # pre_loc = i+4
# f.close()
# all_data=open("tgt.shuf.seg_pos-reweight-long", 'r', encoding='utf-8').readlines()
# with open("tgt.shuf.seg_pos-reweight0-long", 'a', encoding='utf-8-sig') as f:
# for data in all_data:
# if data == '\n':
# continue
# data = data[:-1]
# word_tags = data.split(' ')
# recheck_data = ''
# for word_tag in word_tags:
# split_word_tag = word_tag.split('/')
# if len(split_word_tag) != 2:
# continue
# word = split_word_tag[0]
# tag = split_word_tag[1]
# if not word or not tag:
# continue
# if len(word)>=5 or len(tag)>2:
# continue
# if tag[0]<'a' or tag[-1]>'z' or tag[0]<'a' or tag[-1]>'z':
# continue
# recheck_data += word_tag
# recheck_data += ' '
# if recheck_data:
# recheck_data = recheck_data[:-1]
# recheck_data += '\n'
# f.write(recheck_data)