-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathmain.py
58 lines (44 loc) · 1.43 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 28 11:07:46 2019
@author: chxy
"""
import torch
from trainer import Trainer
from config import get_config
from utils import prepare_dirs, save_config
from data_loader import get_test_loader, get_train_loader
def main(config):
    """Set up directories and data loaders, then train or test a model.

    Args:
        config: configuration object from ``get_config()``. This function
            reads ``use_gpu``, ``num_workers``, ``pin_memory``,
            ``data_dir``, ``batch_size``, ``random_seed``, ``shuffle`` and
            ``is_train`` from it.

    If ``config.is_train`` is truthy, a ``Trainer`` is built on the pair
    ``(train_loader, test_loader)``, the config is persisted with
    ``save_config`` and ``trainer.train()`` runs. Otherwise the ``Trainer``
    receives the test loader alone and ``trainer.test()`` runs.
    """
    prepare_dirs(config)

    # NOTE(review): the seeding calls (torch.manual_seed,
    # torch.cuda.manual_seed_all, torch.backends.cudnn.deterministic) are
    # disabled in this version, so this function does not seed torch;
    # config.random_seed is only forwarded to get_train_loader. Restore
    # those calls here if reproducible runs are required.

    # Extra loader options are forwarded only on GPU runs; CPU runs pass no
    # extra keyword arguments to the loader factories.
    loader_opts = (
        {'num_workers': config.num_workers, 'pin_memory': config.pin_memory}
        if config.use_gpu
        else {}
    )

    # The test loader is needed in both modes and is always built first.
    test_loader = get_test_loader(
        config.data_dir, config.batch_size, **loader_opts
    )

    if not config.is_train:
        data_loader = test_loader
    else:
        train_loader = get_train_loader(
            config.data_dir,
            config.batch_size,
            config.random_seed,
            config.shuffle,
            **loader_opts
        )
        data_loader = (train_loader, test_loader)

    trainer = Trainer(config, data_loader)

    if not config.is_train:
        # Evaluation-only run (per the original comment: load a pretrained
        # model and test).
        trainer.test()
        return

    save_config(config)
    trainer.train()
if __name__ == '__main__':
    # get_config() returns the parsed configuration plus any command-line
    # arguments it did not recognise.
    config, unparsed = get_config()
    if unparsed:
        # Previously these were dropped silently, so a mistyped flag would be
        # ignored and the run would proceed with defaults. Warn rather than
        # fail, since extra arguments may be injected deliberately (e.g. by
        # an IDE or notebook kernel).
        import warnings
        warnings.warn(
            'Ignoring unrecognized arguments: {}'.format(unparsed)
        )
    main(config)