Skip to content

Commit

Permalink
add nexto tick skip to test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyotech committed May 1, 2023
1 parent 19dc8b1 commit a338834
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions learner_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
ent_coef=0.01,
)

run_id = "gp_run3.51"
run_id = "gp_run3.52"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="GP_Run3.51",
name="GP_Run3.52",
project="Opti",
entity="kaiyotech",
id=run_id,
Expand Down Expand Up @@ -154,7 +154,7 @@
disable_gradient_logging=True,
)

alg.load("GP_saves/Opti_1682617673.747051/Opti_40310/checkpoint.pt")
alg.load("GP_saves/Opti_1682795258.7251265/Opti_41230/checkpoint.pt")


alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
Expand Down
3 changes: 2 additions & 1 deletion pretrained_agents/nexto/nexto_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@


class NextoV2(HardcodedAgent):
def __init__(self, model_string, n_players):
def __init__(self, model_string, n_players, tick_skip_skip=2):
cur_dir = os.path.dirname(os.path.realpath(__file__))
self.actor = torch.jit.load(os.path.join(cur_dir, model_string))
self.obs_builder = Nexto_V2_ObsBuilder(n_players=n_players)
self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0])
self.tick_skip_skip = tick_skip_skip

self._lookup_table = self.make_lookup_table()

Expand Down

0 comments on commit a338834

Please sign in to comment.