diff --git a/learner_gp.py b/learner_gp.py index 5124917..50dbba6 100644 --- a/learner_gp.py +++ b/learner_gp.py @@ -52,10 +52,10 @@ ent_coef=0.01, ) - run_id = "gp_run3.51" + run_id = "gp_run3.52" wandb.login(key=os.environ["WANDB_KEY"]) logger = wandb.init(dir="./wandb_store", - name="GP_Run3.51", + name="GP_Run3.52", project="Opti", entity="kaiyotech", id=run_id, @@ -154,7 +154,7 @@ disable_gradient_logging=True, ) - alg.load("GP_saves/Opti_1682617673.747051/Opti_40310/checkpoint.pt") + alg.load("GP_saves/Opti_1682795258.7251265/Opti_41230/checkpoint.pt") alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr diff --git a/pretrained_agents/nexto/nexto_v2.py b/pretrained_agents/nexto/nexto_v2.py index 7d0b0de..45e5bcf 100644 --- a/pretrained_agents/nexto/nexto_v2.py +++ b/pretrained_agents/nexto/nexto_v2.py @@ -14,11 +14,12 @@ class NextoV2(HardcodedAgent): - def __init__(self, model_string, n_players): + def __init__(self, model_string, n_players, tick_skip_skip=2): cur_dir = os.path.dirname(os.path.realpath(__file__)) self.actor = torch.jit.load(os.path.join(cur_dir, model_string)) self.obs_builder = Nexto_V2_ObsBuilder(n_players=n_players) self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0]) + self.tick_skip_skip = tick_skip_skip self._lookup_table = self.make_lookup_table()