# train.py
# -*- coding: utf-8 -*-
"""
@CreateTime : 2022/12/28 21:25
@Author : Qingpeng Wen
@File : train.py
@Software : PyCharm
@Framework : Pytorch
@LastModify : 2022/12/28 23:35
"""
import os
import torch
import json
import random
import numpy as np
import torch.optim as optim
from thop import clever_format, profile
from utils.model.module import ModelManager
from utils.data_loader.loader import DatasetManager
from utils.process import Processor
from utils.config import *
def str2bool(v):
    """Convert a command-line flag value to a ``bool``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. Strings are
            compared case-insensitively, ignoring surrounding whitespace.
            ``yes/true/t/y/1`` map to ``True`` and ``no/false/f/n/0`` map
            to ``False``.

    Returns:
        The boolean value represented by ``v``.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a ``bool`` and is not
            one of the recognised strings (this includes non-string values
            such as ``None``).
    """
    # Imported locally because this file never imports ``argparse``
    # explicitly; the name is otherwise only available if it happens to be
    # re-exported by ``from utils.config import *``.
    import argparse

    if isinstance(v, bool):
        return v

    if isinstance(v, str):
        token = v.strip().lower()
        if token in ('yes', 'true', 't', 'y', '1'):
            return True
        if token in ('no', 'false', 'f', 'n', '0'):
            return False

    # Anything else (unknown word, empty string, non-string object) is
    # reported through the exception type that argparse turns into a usage
    # error instead of an unhandled traceback.
    raise argparse.ArgumentTypeError('Boolean value expected.')
# Path probed at start-up to decide whether training resumes from a saved
# checkpoint or starts from scratch.
model_file_path = os.path.join("save_cais/model")
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _seed_everything(seed):
    """Seed the ``random``, NumPy and PyTorch random number generators."""
    # Fix the random seed of package random and NumPy.
    random.seed(seed)
    np.random.seed(seed)

    # Fix the random seed of Pytorch when using GPU.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)

    # Fix the random seed of Pytorch when using CPU.
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)


def _build_training_processor(args):
    """Create the ``Processor`` used for training.

    If ``model_file_path`` exists, the model, dataset, optimizer and epoch
    counter are restored from it; otherwise everything is built from
    ``args`` and training starts at epoch 0.

    Args:
        args: Parsed command-line arguments.

    Returns:
        A ``Processor`` ready to have ``train()`` called on it.
    """
    if os.path.exists(model_file_path):
        # Load pre-training model.
        # NOTE(review): torch.load unpickles whole Python objects here; only
        # point this at checkpoints you trust. Recent PyTorch releases may
        # also require ``weights_only=False`` for such files - confirm
        # against the PyTorch version in use.
        checkpoint = torch.load(model_file_path, map_location=device)
        model = checkpoint['model']
        dataset = checkpoint["dataset"]
        optimizer = checkpoint["optimizer"]
        start_epoch = checkpoint["epoch"]
        dataset.show_summary()
        model.show_summary()
        process = Processor(dataset, model, optimizer, start_epoch, args.batch_size)
        print('epoch {}: The pre-training model was successfully loaded!'.format(start_epoch))
        return process

    print('No save model will be trained from scratch!')
    start_epoch = 0

    # Instantiate a dataset object.
    dataset = DatasetManager(args)
    dataset.quick_build()
    dataset.show_summary()

    # Instantiate a network model object.
    model = ModelManager(
        args, len(dataset.char_alphabet),
        len(dataset.word_alphabet),
        len(dataset.slot_alphabet),
        len(dataset.intent_alphabet)
    )
    model.show_summary()

    optimizer = optim.Adam(
        model.parameters(),
        lr=dataset.learning_rate,
        weight_decay=dataset.l2_penalty,
    )
    return Processor(dataset, model, optimizer, start_epoch, args.batch_size)


def main():
    """Train (unless ``--do_evaluation`` is set) and then score the saved model.

    In training mode the run parameters are written to
    ``<save_dir>/param.json``, the random seeds are fixed, and training runs
    until it finishes or is interrupted with Ctrl-C. In both modes the model
    and dataset stored under ``<save_dir>/model/`` are then loaded and the
    result of ``Processor.validate`` is printed.
    """
    args = parser.parse_args()

    if args.do_evaluation:
        print("Beginning evaluation:")
    else:
        # Save training and model parameters. os.makedirs is portable and
        # does not pass save_dir through a shell the way "mkdir -p" did.
        os.makedirs(args.save_dir, exist_ok=True)
        log_path = os.path.join(args.save_dir, "param.json")
        with open(log_path, "w", encoding="utf-8") as fw:
            fw.write(json.dumps(vars(args), indent=True))

        _seed_everything(args.random_state)

        # To train and evaluate the models.
        process = _build_training_processor(args)
        try:
            process.train()
        except KeyboardInterrupt:
            # Fall through so whatever was saved so far still gets scored.
            print("Exiting from training early.")

    # map_location keeps these loads usable on a CPU-only machine even when
    # the files were written from a GPU run, matching the checkpoint load.
    model = torch.load(
        os.path.join(args.save_dir, "model/model.pkl"), map_location=device)
    dataset = torch.load(
        os.path.join(args.save_dir, "model/dataset.pkl"), map_location=device)

    print('\nAccepted performance: ' + str(Processor.validate(
        model, dataset, args.batch_size * 2)) + " at test dataset;\n")


if __name__ == "__main__":
    main()