import argparse
import datetime
import itertools
import os
import random
import traceback

import hydra
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision
import yaml
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict
from tqdm import tqdm

from dataloader.dataloader import DataloaderMode, create_dataloader
from model.model_handler import Model_handler
# from model.network import Network
from utils.utils import get_logger, is_logging_process, set_random_seed, print_config
from utils.writer import Writer


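# Illustrative sketch of the distributed settings this script reads from the
# Hydra config (key names taken from the accesses below; values are examples):
#
#   dist:
#     master_addr: localhost
#     master_port: "12355"
#     mode: nccl          # backend passed to dist.init_process_group
#     gpus: -1            # -1 = use all visible GPUs, 0 = single process
#     timeout: null       # seconds; null keeps the 1800 s default below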
def setup(cfg, rank):
    os.environ["MASTER_ADDR"] = cfg.dist.master_addr
    os.environ["MASTER_PORT"] = cfg.dist.master_port
    timeout_sec = 1800
    if cfg.dist.timeout is not None:
        os.environ["NCCL_BLOCKING_WAIT"] = "1"
        timeout_sec = cfg.dist.timeout
    timeout = datetime.timedelta(seconds=timeout_sec)
    # initialize the process group
    dist.init_process_group(
        cfg.dist.mode,
        rank=rank,
        world_size=cfg.dist.gpus,
        timeout=timeout,
    )


def cleanup():
    dist.destroy_process_group()


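# mp.spawn calls `fn(i, *args)` once per child process, passing the process
# index as the first argument, so train_loop receives its rank followed by cfg.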
def distributed_run(fn, cfg):
    mp.spawn(fn, args=(cfg,), nprocs=cfg.dist.gpus, join=True)


def train_loop(rank, cfg):
    logger = get_logger(cfg, os.path.basename(__file__))
    if cfg.device == "cuda" and cfg.dist.gpus != 0:
        cfg.device = rank
        # turn off background generator when distributed run is on
        cfg.data.use_background_generator = False
        setup(cfg, rank)
        torch.cuda.set_device(cfg.device)

    writer = None
    # setup writer
    if is_logging_process():
        # set log/checkpoint dir
        os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
        # set writer (tensorboard / wandb)
        writer = Writer(cfg, "tensorboard")
        if cfg.data.train_dir == "" or cfg.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")
        logger.info("BackgroundGenerator is turned off when distributed run is on")

    # download MNIST dataset before making dataloader
    # TODO: This is example code. You should change this part as you need.
    _ = torchvision.datasets.MNIST(
        root=hydra.utils.to_absolute_path("dataset/meta"),
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    _ = torchvision.datasets.MNIST(
        root=hydra.utils.to_absolute_path("dataset/meta"),
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    # sync dist processes (because of downloading the MNIST dataset)
    if cfg.dist.gpus != 0:
        dist.barrier()

    # make dataloader
    if is_logging_process():
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(cfg, DataloaderMode.train, rank)
    if is_logging_process():
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(cfg, DataloaderMode.test, rank)

    # init Model
    ###############################################################################
    # Remark: hydra.utils.instantiate is not for distributed Data Parallel training.
    # For DDP, replace the next line with `net_arch = Network(cfg)` and import the
    # Network module above.
    ###############################################################################
    net_arch = hydra.utils.instantiate(cfg.model, cfg=cfg)
    # net_arch = Network(cfg)
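    # Hypothetical example of a Hydra-instantiable model config node that the
    # call above assumes (exact fields depend on your config/model/*.yaml):
    #
    #   model:
    #     _target_: model.network.Network
    #
    # hydra.utils.instantiate(cfg.model, cfg=cfg) would then construct
    # model.network.Network(cfg=cfg).
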
    loss_f = torch.nn.CrossEntropyLoss()
    model = Model_handler(cfg, net_arch, loss_f, writer, rank)

    # load training state / network checkpoint
    if cfg.load.resume_state_path is not None:
        model.load_training_state()
    elif cfg.load.network_chkpt_path is not None:
        model.load_network()
    else:
        if is_logging_process():
            logger.info("Starting new training run.")
    try:
        if cfg.dist.gpus == 0 or cfg.data.divide_dataset_per_gpu:
            epoch_step = 1
        else:
            epoch_step = cfg.dist.gpus
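        # When the dataset is not divided per GPU, each rank presumably sees the
        # full dataset, so one pass across all ranks is counted as
        # cfg.dist.gpus epochs.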
        for epoch in tqdm(
            range(model.epoch + 1, cfg.train.num_epoch, epoch_step),
            desc="Epoch",
            unit="epoch",
        ):
            model.epoch = epoch
            model.train_model(train_loader)
            if model.epoch % cfg.log.chkpt_interval == 0:
                model.save_network()
                model.save_training_state()
            model.test_model(test_loader)
        if is_logging_process():
            logger.info("End of training")
    except Exception:
        if is_logging_process():
            logger.error(traceback.format_exc())
        else:
            traceback.print_exc()
    finally:
        if cfg.dist.gpus != 0:
            cleanup()


@hydra.main(version_base="1.1", config_path="config", config_name="default")
def main(hydra_cfg):
    hydra_cfg.device = hydra_cfg.device.lower()
    with open_dict(hydra_cfg):
        hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging
    print_config(
        hydra_cfg,
        get_logger(hydra_cfg, os.path.basename(__file__), disable_console=True),
    )
    # random seed
    if hydra_cfg.random_seed is None:
        hydra_cfg.random_seed = random.randint(1, 10000)
    set_random_seed(hydra_cfg.random_seed)

    if hydra_cfg.dist.gpus < 0:
        hydra_cfg.dist.gpus = torch.cuda.device_count()
    if hydra_cfg.device == "cpu" or hydra_cfg.dist.gpus == 0:
        hydra_cfg.dist.gpus = 0
        train_loop(0, hydra_cfg)
    else:
        # resolve ${hydra:runtime.cwd} in work_dir here, because the hydra
        # resolver is not supported inside the spawned DDP worker processes
        hydra_cfg.work_dir = hydra_cfg.work_dir
        distributed_run(train_loop, hydra_cfg)


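# Illustrative launch commands (assuming config/default.yaml defines the keys
# referenced above, e.g. device, dist.gpus, train.num_epoch):
#
#   python trainer.py                # defaults from config/default.yaml
#   python trainer.py device=cpu     # force single-process CPU training
#   python trainer.py dist.gpus=-1   # DDP across all visible GPUs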
if __name__ == "__main__":
    main()