Skip to content

Commit

Permalink
restart halfflip with correct checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyotech committed Feb 11, 2023
1 parent b131e88 commit f6e93dc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
22 changes: 10 additions & 12 deletions learner_half_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@
ent_coef=0.01,
)

run_id = "halfflip_run1.00"
run_id = "halfflip_run2.00"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="Halfflip_Run1.00",
name="Halfflip_Run2.00",
project="Opti",
entity="kaiyotech",
id=run_id,
Expand Down Expand Up @@ -86,21 +86,18 @@
add_fliptime=True,
add_boosttime=True,
add_handbrake=True,
flip_dir=False),
),
lambda: ZeroSumReward(zero_sum=Constants_half_flip.ZERO_SUM,
velocity_pb_w=0.01,
boost_gain_w=0.35,
boost_spend_w=3,
punish_boost=True,
touch_ball_w=2,
boost_remain_touch_w=1.5,
touch_grass_w=-0.01,
supersonic_bonus_vpb_w=0,
zero_touch_grass_if_ss=False,
turtle_w=0,
touch_grass_w=-0.005,
final_reward_ball_dist_w=1,
final_reward_boost_w=0.2,
tick_skip=frame_skip
tick_skip=frame_skip,
),
lambda: CoyoteAction(),
save_every=logger.config.save_every * 3,
Expand All @@ -112,11 +109,11 @@
max_age=1,
)

critic = Sequential(Linear(227, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
critic = Sequential(Linear(229, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
Linear(256, 128), LeakyReLU(),
Linear(128, 1))

actor = Sequential(Linear(227, 128), LeakyReLU(), Linear(128, 128), LeakyReLU(),
actor = Sequential(Linear(229, 128), LeakyReLU(), Linear(128, 128), LeakyReLU(),
Linear(128, 128), LeakyReLU(),
Linear(128, 373))

Expand Down Expand Up @@ -146,10 +143,11 @@

)

alg.load("recovery_saves/Opti_1675223912.912102/Opti_430/checkpoint.pt")
# alg.load("recovery_saves/Opti_1675223912.912102/Opti_430/checkpoint.pt")
alg.load("recovery_saves/Opti_1675355797.047272/Opti_680/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr

alg.freeze_policy(20)
# alg.freeze_policy(20)

alg.run(iterations_per_save=logger.config.save_every, save_dir="half_flip_saves")
9 changes: 3 additions & 6 deletions worker_half_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,10 @@
punish_boost=True,
touch_ball_w=2,
boost_remain_touch_w=1.5,
touch_grass_w=-0.01,
supersonic_bonus_vpb_w=0,
zero_touch_grass_if_ss=False,
turtle_w=0,
touch_grass_w=-0.005,
final_reward_ball_dist_w=1,
final_reward_boost_w=0.2,
tick_skip=frame_skip
tick_skip=frame_skip,
)

fps = 120 // frame_skip
Expand Down Expand Up @@ -84,7 +81,7 @@
add_fliptime=True,
add_boosttime=True,
add_handbrake=True,
flip_dir=False),
),
action_parser=CoyoteAction(),
terminal_conditions=[GoalScoredCondition(),
TimeoutCondition(fps * 100),
Expand Down

0 comments on commit f6e93dc

Please sign in to comment.