-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun.py
40 lines (38 loc) · 2.53 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from model import CodeBert_Seq2Seq
from utils import set_seed
# set_seed is the project helper imported from utils, called with its default
# arguments; presumably it fixes the random seeds so runs are repeatable.
set_seed()

# Dataset name, used as a sub-directory of the output folders below.
# Deliberately not called `type`, which would shadow the builtin of that name.
dataset = 'EVIL_assembly'

# Paths of the available pretrained encoder checkpoints, keyed by a short name.
model_type = {
    'codebert': '/home/yangguang/models/codebert-base',
    'pseudocodebert': '/home/yangguang/PycharmProjects/Pretrain-BERT/pseudo-codebert/best',
    'fgcodebert': '/home/yangguang/PycharmProjects/Pretrain-BERT/fg-codebert/best',
}

# Experiment name; together with `dataset` it selects the output sub-directory.
# task = 'codebert_raw_nl_2_ip_code'
task = 'pseudocodebert_hybrid_attention_0903'

# Output locations: valid_output/<dataset>/<task> and test_output/<dataset>/<task>.
# Built once so the checkpoint reloaded for testing always matches the directory
# that training wrote to.
valid_output_dir = '/'.join(('valid_output', dataset, task))
test_output_dir = '/'.join(('test_output', dataset, task))

# NOTE(review): the same CSV is used as the dev set during training and as the
# final test set, exactly as in the original script.
test_file = '/home/yangguang/PycharmProjects/CodeBert/data/EVIL_assembly/w_ip/test_hybrid.csv'

# Hyperparameters shared by the training-time and the test-time model. Keeping
# them in one place guarantees both models are built with identical settings.
model_options = dict(
    decoder_layers=6,
    fix_encoder=False,
    beam_size=10,
    max_source_length=64,
    max_target_length=64,
    layer_attention=True,
    l2_norm=True,
    fusion=True,
)

# Initialise the model used for training. The two encoders come from earlier
# fine-tuned runs, and load_model_path points at an existing checkpoint
# (presumably used to warm-start training; confirm in CodeBert_Seq2Seq).
model = CodeBert_Seq2Seq(
    ip_path='/home/yangguang/PycharmProjects/CodeBert/model/EVIL_assembly/pseudocodebert_ip_nl_2_ip_code/encoder',
    raw_path='/home/yangguang/PycharmProjects/CodeBert/model/EVIL_assembly/pseudocodebert_raw_nl_2_ip_code/encoder',
    load_model_path='/home/yangguang/PycharmProjects/HIP/valid_output/EVIL_assembly/pseudocodebert_hybrid/checkpoint-best-rouge/pytorch_model.bin',
    **model_options
)

# Model training, with evaluation on the dev file enabled.
model.train(
    train_filename='/home/yangguang/PycharmProjects/CodeBert/data/EVIL_assembly/w_ip/train_hybrid.csv',
    train_batch_size=32,
    num_train_epochs=50,
    learning_rate=4e-5,
    do_eval=True,
    dev_filename=test_file,
    eval_batch_size=64,
    output_dir=valid_output_dir,
)

# Rebuild the model from the pretrained pseudocodebert encoder and load the
# checkpoint found under the training output directory (presumably the best
# ROUGE checkpoint written by model.train; confirm against its implementation).
model = CodeBert_Seq2Seq(
    ip_path=model_type['pseudocodebert'],
    raw_path=model_type['pseudocodebert'],
    load_model_path=valid_output_dir + '/checkpoint-best-rouge/pytorch_model.bin',
    **model_options
)

# Model testing.
model.test(
    test_filename=test_file,
    test_batch_size=16,
    output_dir=test_output_dir,
)

# Model inference (example usage, kept disabled):
# comment = model.predict(source = "get the hexadecimal value of suplX and reverse its order then store the value in rev_suplx",
#                         similarity='get hexadecimal value of var0 and reverse its order then store value in var')
# print(comment)