@@ -74,7 +74,7 @@ def train(*, env_id, num_env, hps, num_timesteps, seed):
         update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
         int_coeff=hps.pop('int_coeff'),
         ext_coeff=hps.pop('ext_coeff'),
-        restore_model=hps.pop('restore_model')
+        restore_model_path=hps.pop('restore_model_path')
     )
     agent.start_interaction([venv])
     if hps.pop('update_ob_stats_from_random_agent'):
@@ -139,7 +139,8 @@ def main():
     parser.add_argument('--ext_coeff', type=float, default=2.)
     parser.add_argument('--dynamics_bonus', type=int, default=0)
     parser.add_argument('--save_model', action='store_true')
-    parser.add_argument('--restore_model', action='store_true')
+    parser.add_argument('--restore_model_path', type=str, default='',
+                        help='Path to the saved_model dir containing the ppo checkpoint')
     parser.add_argument('--experiment', type=str, default='ego', choices=['baseline', 'attention', 'ego'])

     args = parser.parse_args()
@@ -175,7 +176,7 @@ def main():
         ext_coeff=args.ext_coeff,
         dynamics_bonus=args.dynamics_bonus,
         save_model=args.save_model,
-        restore_model=args.restore_model
+        restore_model_path=args.restore_model_path
     )

     tf_util.make_session(make_default=True)
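Note: with this change the boolean --restore_model switch becomes a string --restore_model_path, so the checkpoint directory is supplied on the command line and threaded through hps into the agent. Below is a minimal, self-contained sketch of that plumbing only; PpoAgent here is a stand-in stub, and the actual checkpoint loading (which is not part of these hunks) would happen inside the real agent.

# Minimal sketch of the plumbing shown in the diff (assumption: PpoAgent is a
# stub; the real agent restores its ppo checkpoint from this directory).
import argparse


class PpoAgent:
    def __init__(self, *, restore_model_path=''):
        # The real agent would load its checkpoint from restore_model_path
        # (e.g. via a TF saver) when the path is non-empty.
        self.restore_model_path = restore_model_path


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_model_path', type=str, default='',
                        help='Path to the saved_model dir containing the ppo checkpoint')
    args = parser.parse_args(argv)

    # Mirror the script: the flag goes into hps, and train() pops it back out
    # before constructing the agent.
    hps = dict(restore_model_path=args.restore_model_path)
    agent = PpoAgent(restore_model_path=hps.pop('restore_model_path'))
    return agent


if __name__ == '__main__':
    # Example invocation with a hypothetical checkpoint directory.
    main(['--restore_model_path', '/tmp/ppo_checkpoints'])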