-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtraining.py
84 lines (72 loc) · 2.6 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
# @Time : May 25
# @Author : Xuyang SHEN, Alisdair Cameron, Xinqi Zhu
# @File : training.py
# @IDE: PyCharm Community Edition
import os
import time

import tensorflow as tf

from enet.ENet import *
from data_provider.data_dataset import *
from data_provider.label2tfrecord import *
from config import *
# ---------------------------------------------------------
# run parser
# ---------------------------------------------------------
# Parse the command-line options once, then copy the values this script
# uses into module-level names.
config = parse_cmd_training_args()

# Training-set address and model address options.
train_add, model_add = config.trainset_address, config.model_add

# Input-pipeline options.
# NOTE(review): batch_size is read here but is not referenced again in the
# visible part of this script -- confirm whether it should be passed on.
epochs, batch_size, buffer_size = (
    config.num_epochs,
    config.batch_size,
    config.buffer_size,
)
# ---------------------------------------------------------
# offline label generator
# ---------------------------------------------------------
# Optional pre-processing pass, enabled only when the parsed option
# `offline_label_generator` equals the string 'yes'.
# NOTE(review): this copy of the script had lost its indentation, so the
# extent of the `if` body is reconstructed. The closing asterisk banner is
# assumed to belong to the generator branch -- confirm against the original.
if config.offline_label_generator == 'yes':
    print("---------------------* offline label generator begins *--------------------")
    # TfGenerator is project code pulled in by a wildcard import; presumably
    # it converts the ground-truth labels for later use -- verify there.
    # The value returned by run() is kept under its original name.
    tusimple = TfGenerator(
        path=train_add,
        js_name=config.ground_truth,
    ).run()
    print("***************************************************************************")

# Announce the start and remember the wall-clock time; `starter` is used for
# the elapsed-time report printed after training.
print("program begins to training, at", time.ctime(), '\n')
starter = time.time()
# ---------------------------------------------------------
# initial data
# ---------------------------------------------------------
# Build the training dataset object from the parsed options; its
# `dataset_input_fn` is what gets handed to the estimator as input_fn.
data = TDataset(json_add=train_add, num_epochs=epochs, buffer_size=buffer_size)

# Spacer in the console output.
print("\n")
# ---------------------------------------------------------
# training logging information
# ---------------------------------------------------------
# Emit TensorFlow log messages at INFO level.
tf.logging.set_verbosity(tf.logging.INFO)

# Hook that logs every 50 iterations; the dict maps the printed tag
# 'mean_iou' to the tensor name 'accuracy'.
# NOTE(review): the tensor name must match what the model function defines --
# verify against the ENet model code.
tensors_to_log = dict(mean_iou='accuracy')
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log,
    every_n_iter=50,
)

# Checkpoint every 20 minutes (value is in seconds) and keep at most 40
# checkpoints. Original author's note: about 60 s/step, 9-10 steps/epoch.
est_config = tf.estimator.RunConfig(
    keep_checkpoint_max=40,
    save_checkpoints_secs=20 * 60,
)
# ---------------------------------------------------------
# initialize the model
# ---------------------------------------------------------
# Resolve the model directory against the current working directory
# (os.path.join returns model_add unchanged when it is already absolute).
# `os` is imported explicitly at the top of the file instead of relying on a
# wildcard import to re-export it.
model_address = os.path.join(os.getcwd(), model_add)

# Estimator built around the ENet model function; it uses model_address as its
# model directory and the run configuration prepared in est_config.
lane_detect = tf.estimator.Estimator(
    model_fn=ENet,
    model_dir=model_address,
    config=est_config,
)
print("---------------------********** training **********s--------------------")
# ---------------------------------------------------------
# training process
# ---------------------------------------------------------
# No step limit is passed (steps=None); per the Estimator API, training then
# runs until the input function signals that it is exhausted.
lane_detect.train(input_fn=data.dataset_input_fn, steps=None, hooks=[logging_hook])

# Report the wall-clock duration measured from `starter`, then the model location.
elapsed_seconds = time.time() - starter
print(
    "All the training is finished",
    " It takes ",
    "%0.2f" % elapsed_seconds,
    'seconds',
)
print("\ntraining model store at: ", model_address)