Skip to content

Commit bb173db

Browse files
committed
Added model checkpoint to the training result directory
1 parent d19da6c commit bb173db

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

ppo_agent.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def __init__(self, *, scope,
110110
update_ob_stats_every_step=True,
111111
int_coeff=None,
112112
ext_coeff=None,
113-
restore_model=False):
113+
restore_model_path=None):
114114

115115
self.lr = lr
116116
self.ext_coeff = ext_coeff
@@ -201,11 +201,11 @@ def __init__(self, *, scope,
201201
if self.is_log_leader:
202202
tf_util.display_var_info(allvars)
203203

204-
model_path = os.path.join(os.getcwd(), 'saved_model')
204+
model_path = os.path.join(logger.get_dir(), 'saved_model')
205205
self.model_path = os.path.join(model_path, 'ppo.ckpt')
206206

207-
if restore_model:
208-
tf_util.load_state(model_path)
207+
if restore_model_path:
208+
tf_util.load_state(restore_model_path)
209209
else:
210210
#self.activate_graph_debugging()
211211
tf.get_default_session().run(tf.variables_initializer(allvars))

run_atari.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def train(*, env_id, num_env, hps, num_timesteps, seed):
7474
update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
7575
int_coeff=hps.pop('int_coeff'),
7676
ext_coeff=hps.pop('ext_coeff'),
77-
restore_model=hps.pop('restore_model')
77+
restore_model_path=hps.pop('restore_model_path')
7878
)
7979
agent.start_interaction([venv])
8080
if hps.pop('update_ob_stats_from_random_agent'):
@@ -139,7 +139,8 @@ def main():
139139
parser.add_argument('--ext_coeff', type=float, default=2.)
140140
parser.add_argument('--dynamics_bonus', type=int, default=0)
141141
parser.add_argument('--save_model', action='store_true')
142-
parser.add_argument('--restore_model', action='store_true')
142+
parser.add_argument('--restore_model_path', type=str, default='',
143+
help='Path to the saved_model dir containing the ppo checkpoint')
143144
parser.add_argument('--experiment', type=str, default='ego', choices=['baseline', 'attention', 'ego'])
144145

145146
args = parser.parse_args()
@@ -175,7 +176,7 @@ def main():
175176
ext_coeff=args.ext_coeff,
176177
dynamics_bonus = args.dynamics_bonus,
177178
save_model=args.save_model,
178-
restore_model=args.restore_model
179+
restore_model_path=args.restore_model_path
179180
)
180181

181182
tf_util.make_session(make_default=True)

0 commit comments

Comments
 (0)