train_DCMM.py
from __future__ import annotations
import hydra
import torch
import os
import random
import wandb
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
from termcolor import cprint
from gym_dcmm.utils.util import omegaconf_to_dict
from gym_dcmm.algs.ppo_dcmm.ppo_dcmm_catch_two_stage import PPO_Catch_TwoStage
from gym_dcmm.algs.ppo_dcmm.ppo_dcmm_catch_one_stage import PPO_Catch_OneStage
from gym_dcmm.algs.ppo_dcmm.ppo_dcmm_track import PPO_Track
import gymnasium as gym
import gym_dcmm
import datetime
import pytz
# os.environ['MUJOCO_GL'] = 'egl'
OmegaConf.register_new_resolver('resolve_default', lambda default, arg: default if arg=='' else arg)
@hydra.main(config_name='config', config_path='configs')
def main(config: DictConfig):
    torch.multiprocessing.set_start_method('spawn')
    # Resolve the checkpoint path for the selected task (Hydra changes the working directory at launch).
    model_path = None
    if config.task == 'Tracking' and config.checkpoint_tracking:
        config.checkpoint_tracking = to_absolute_path(config.checkpoint_tracking)
        model_path = config.checkpoint_tracking
    elif (config.task == 'Catching_TwoStage'
          or config.task == 'Catching_OneStage') \
            and config.checkpoint_catching:
        config.checkpoint_catching = to_absolute_path(config.checkpoint_catching)
        model_path = config.checkpoint_catching
    # Use the requested device for RL and seed Python's RNG.
    config.rl_device = f'cuda:{config.device_id}' if config.device_id >= 0 else 'cpu'
    random.seed(config.seed)
    cprint('Start Building the Environment', 'green', attrs=['bold'])
    # Create and wrap the vectorized environment.
    env_name = 'gym_dcmm/DcmmVecWorld-v0'
    task = 'Tracking' if config.task == 'Tracking' else 'Catching'
    print("config.num_envs: ", config.num_envs)
    env = gym.make_vec(env_name, num_envs=int(config.num_envs),
                       task=task, camera_name=["top"],
                       render_per_step=False, render_mode="rgb_array",
                       object_name="object",
                       img_size=config.train.ppo.img_dim,
                       imshow_cam=config.imshow_cam,
                       viewer=config.viewer,
                       print_obs=False, print_info=False,
                       print_reward=False, print_ctrl=False,
                       print_contacts=False, object_eval=config.object_eval,
                       env_time=2.5, steps_per_policy=20)
    # Build a timestamped output directory using the local date and time.
    output_dir = os.path.join('outputs', config.output_name)
    local_tz = pytz.timezone('Asia/Shanghai')
    current_datetime = datetime.datetime.now().astimezone(local_tz)
    current_datetime_str = current_datetime.strftime("%Y-%m-%d/%H:%M:%S")
    output_dir = os.path.join(output_dir, current_datetime_str)
    os.makedirs(output_dir, exist_ok=True)
    # Pick the PPO variant that matches the task.
    PPO = PPO_Track if config.task == 'Tracking' else \
        PPO_Catch_TwoStage if config.task == 'Catching_TwoStage' else \
        PPO_Catch_OneStage
    agent = PPO(env, output_dir, full_config=config)
    cprint('Start Training/Testing the Agent', 'green', attrs=['bold'])
    if config.test:
        if model_path:
            print("checkpoint loaded")
            agent.restore_test(model_path)
        print("testing")
        agent.test()
    else:
        # Connect to wandb for experiment logging.
        wandb.init(
            project=config.wandb_project,
            entity=config.wandb_entity,
            name=config.output_name,
            config=omegaconf_to_dict(config),
            mode=config.wandb_mode
        )
        agent.restore_train(model_path)
        agent.train()
        # Close wandb.
        wandb.finish()
if __name__ == '__main__':
    main()
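# Example invocation (a sketch, not taken from the repository docs): the Hydra
# override keys below (task, num_envs, device_id, test, checkpoint_catching,
# output_name) are the ones this script reads; the values shown are illustrative,
# and the actual defaults live in configs/config.yaml, which is not part of this file.
#
#   python train_DCMM.py task=Tracking num_envs=16 device_id=0 output_name=track_run
#   python train_DCMM.py task=Catching_TwoStage test=True \
#       checkpoint_catching=outputs/catch_run/model.pth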