-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
137 lines (119 loc) · 4.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from deprl.main import set_tensor_device
from myutils import logger
from myutils.trainer import MyTrainer
from deprl import custom_distributed
from deprl.utils import load_checkpoint
import torch
import os
from customreward import showingweight, settingrewardweight, settingtypeweight
import datetime
def main(config, setting=False):
    """Select the tensor device, then run training.

    Args:
        config: Experiment configuration. When the optional
            ``config["tonic"]["cpu_override"]`` entry is truthy, torch is
            forced onto the CPU; otherwise deprl's ``set_tensor_device()``
            decides.
        setting: Passed through unchanged to ``train``.
    """
    tonic_section = config["tonic"]
    if not tonic_section.get("cpu_override"):
        # Defer device selection to deprl's helper.
        set_tensor_device()
    else:
        torch.set_default_device("cpu")
        logger.log("Manually forcing CPU run.")
    train(config, setting=setting)
def train(config, setting=False):
    """Train an agent on an environment (follows the deprl library's trainer).

    Several entries of ``config["tonic"]`` are Python source strings:
    ``header`` is executed first (typically imports), ``agent`` and
    ``trainer`` are evaluated to build those objects, and the optional
    ``before_training`` / ``after_training`` hooks are executed around the
    training run. All of them share one explicit namespace, so names created
    by the header are visible to every later expression.

    NOTE: these strings are run with ``exec``/``eval``; only use trusted
    configuration files.

    Args:
        config: Experiment configuration dict. Reads the ``tonic``,
            ``env_args``, ``test_env_args``, ``weights``, ``mpo_args`` and
            ``DEP`` sections; ``env_args`` is normalized to ``{}`` in place
            when missing or ``None``.
        setting: Forwarded to ``settingtypeweight`` and
            ``settingrewardweight``.

    Raises:
        ValueError: If no agent expression is configured.
    """
    tonic_conf = config["tonic"]

    # Explicit namespace for the config-provided code. Executing into the
    # implicit function-local scope is fragile: since Python 3.13 every
    # ``locals()`` call returns a fresh snapshot, so names bound by ``exec``
    # (e.g. ``import deprl`` in the header) would be invisible to later
    # ``eval`` calls.
    namespace = dict(globals())

    # Run the header first, e.g. to load an ML framework.
    if tonic_conf.get("header"):
        exec(tonic_conf["header"], namespace)

    # In case no env_args are passed via the config.
    if config.get("env_args") is None:
        config["env_args"] = {}

    # Configure the custom type/reward weights; showingweight() presumably
    # reports the resulting values.
    settingtypeweight(config["weights"]["type"], setting=setting)
    settingrewardweight(reward_weights=config["weights"]["reward"], setting=setting)
    showingweight()

    # Build the training environment.
    train_env_spec = tonic_conf["environment"]
    environment = custom_distributed.distribute(
        environment=train_env_spec,
        tonic_conf=tonic_conf,
        env_args=config["env_args"],
    )
    environment.initialize(seed=tonic_conf["seed"])

    # Build the testing environment. A missing or None test environment /
    # test_env_args falls back to the training counterpart.
    test_env_spec = tonic_conf.get("test_environment")
    if test_env_spec is None:
        test_env_spec = train_env_spec
    test_env_args = config.get("test_env_args")
    if test_env_args is None:
        test_env_args = config["env_args"]
    test_environment = custom_distributed.distribute(
        environment=test_env_spec,
        tonic_conf=tonic_conf,
        env_args=test_env_args,
        parallel=1,
        sequential=1,
    )
    # Use a seed distinct from the training environment's.
    test_environment.initialize(seed=tonic_conf["seed"] + 1000000)

    # Build the agent.
    agent_expr = tonic_conf.get("agent")
    if agent_expr is None:
        raise ValueError("No agent specified.")
    agent = eval(agent_expr, namespace, locals())

    # Set custom MPO parameters.
    if "mpo_args" in config:
        agent.set_params(**config["mpo_args"])
    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space,
        seed=tonic_conf["seed"],
    )

    # Set DEP parameters (only for agents that expose an ``expl`` component).
    if hasattr(agent, "expl") and "DEP" in config:
        agent.expl.set_params(config["DEP"])

    # Initialize the logger to get paths.
    logger.initialize(
        script_path=__file__,
        config=config,
        test_env=test_environment,
        resume=tonic_conf["resume"],
    )
    path = logger.get_path()

    # Resume from the last checkpoint under the logger's path, if one exists.
    checkpoint_dir = os.path.join(path, "checkpoints")
    _, checkpoint_path, loaded_time_dict = load_checkpoint(
        checkpoint_dir, checkpoint="last"
    )
    if loaded_time_dict is None:
        time_dict = {"steps": 0, "epochs": 0, "episodes": 0}
    else:
        time_dict = loaded_time_dict
    if checkpoint_path:
        # Load the logger from a checkpoint.
        logger.load(checkpoint_path, time_dict)
        # Load the weights of the agent from a checkpoint.
        agent.load(checkpoint_path)

    # Build the trainer. The fallback uses the imported MyTrainer; the old
    # fallback "tonic_conf.Trainer()" could never evaluate because tonic_conf
    # is a dict. NOTE(review): assumes MyTrainer() works with default
    # arguments — confirm.
    trainer_expr = tonic_conf.get("trainer") or "MyTrainer()"
    trainer = eval(trainer_expr, namespace, locals())
    trainer.initialize(
        agent=agent,
        environment=environment,
        test_environment=test_environment,
        full_save=tonic_conf["full_save"],
    )

    # Run some code before training.
    if tonic_conf.get("before_training"):
        exec(tonic_conf["before_training"], namespace, locals())

    # Train.
    try:
        print("Training started.")
        trainer.run(config, **time_dict)
        print("Training finished.")
    except Exception as e:
        logger.log(f"trainer failed. Exception: {e}")
        raise

    # Run the configured code after training.
    if tonic_conf.get("after_training"):
        exec(tonic_conf["after_training"], namespace, locals())