Skip to content

Commit

Permalink
recovery changes to be race to moving ball, 1630.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyotech committed Feb 5, 2023
1 parent 5b07386 commit 0ff5bef
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 32 deletions.
12 changes: 6 additions & 6 deletions learner_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
fps = 120 / frame_skip
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds))
config = dict(
actor_lr=3e-5,
critic_lr=3e-5,
actor_lr=9e-5,
critic_lr=9e-5,
n_steps=Constants_recovery.STEP_SIZE,
batch_size=250_000,
minibatch_size=125_000,
Expand All @@ -55,10 +55,10 @@
ent_coef=0.01,
)

run_id = "recovery_run11.02"
run_id = "recovery_run11.03"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="Recovery_Run11.02",
name="Recovery_Run11.03",
project="Opti",
entity="kaiyotech",
id=run_id,
Expand Down Expand Up @@ -173,10 +173,10 @@

)

alg.load("recovery_saves/Opti_1675515807.4002893/Opti_1320/checkpoint.pt")
alg.load("recovery_saves/Opti_1675569709.6808238/Opti_1630/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr

alg.freeze_policy(20)

alg.run(iterations_per_save=logger.config.save_every, save_dir="recovery_saves")
alg.run(iterations_per_save=logger.config.save_every, save_dir="recovery_ball_saves")
66 changes: 41 additions & 25 deletions mybots_statesets.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ def reset(self, state_wrapper: StateWrapper):


class RecoverySetter(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight=zero_boost_weight
self.zero_ball_vel_weight=zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand Down Expand Up @@ -346,14 +348,18 @@ def reset(self, state_wrapper: StateWrapper):
loc = random_valid_loc()
state_wrapper.ball.set_pos(x=loc[0], y=loc[1], z=94)
if self.rng.uniform() > self.zero_ball_vel_weight:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-200, 200), self.rng.uniform(-200, 200), self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-200, 200),
self.ball_vel_mult * self.rng.uniform(-200, 200),
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
else:
state_wrapper.ball.set_lin_vel(0, 0, 0)
state_wrapper.ball.set_ang_vel(0, 0, 0)


class HalfFlip(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_ball_vel_weight = zero_ball_vel_weight
self.zero_boost_weight = zero_boost_weight
self.rng = np.random.default_rng()
Expand All @@ -369,9 +375,9 @@ def reset(self, state_wrapper: StateWrapper):
if zero_ball_vel:
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.ball_vel_mult * self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if self.rng.uniform() > self.zero_boost_weight:
boost = self.rng.uniform(0, 1.000001)
Expand All @@ -395,7 +401,9 @@ def reset(self, state_wrapper: StateWrapper):


class Wavedash(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight = zero_boost_weight
self.zero_ball_vel_weight = zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand All @@ -411,9 +419,9 @@ def reset(self, state_wrapper: StateWrapper):
if zero_ball_vel:
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.ball_vel_mult * self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if self.rng.uniform() > self.zero_boost_weight:
boost = self.rng.uniform(0, 1.000001)
Expand All @@ -437,7 +445,9 @@ def reset(self, state_wrapper: StateWrapper):


class Chaindash(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight = zero_boost_weight
self.zero_ball_vel_weight = zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand All @@ -457,9 +467,9 @@ def reset(self, state_wrapper: StateWrapper):
if zero_ball_vel:
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.ball_vel_mult * self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if self.rng.uniform() > self.zero_boost_weight:
boost = self.rng.uniform(0, 1.000001)
Expand All @@ -486,7 +496,9 @@ def reset(self, state_wrapper: StateWrapper):


class RandomEvenRecovery(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight = zero_boost_weight
self.zero_ball_vel_weight = zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand All @@ -510,9 +522,9 @@ def reset(self, state_wrapper: StateWrapper):
if zero_ball_vel:
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-600, 600) if y == 0 and x != 0 else 0,
self.ball_vel_mult * self.rng.uniform(-600, 600) if x == 0 and y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if self.rng.uniform() > self.zero_boost_weight:
boost = self.rng.uniform(0, 1.000001)
Expand All @@ -538,7 +550,9 @@ def reset(self, state_wrapper: StateWrapper):


class Curvedash(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight = zero_boost_weight
self.zero_ball_vel_weight = zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand All @@ -558,9 +572,9 @@ def reset(self, state_wrapper: StateWrapper):
if zero_ball_vel:
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(self.rng.uniform(-600, 600) if ball_y == 0 and ball_x != 0 else 0,
self.rng.uniform(-600, 600) if ball_x == 0 and ball_y != 0 else 0,
self.rng.uniform(-200, 200))
state_wrapper.ball.set_lin_vel(self.ball_vel_mult * self.rng.uniform(-600, 600) if ball_y == 0 and ball_x != 0 else 0,
self.ball_vel_mult * self.rng.uniform(-600, 600) if ball_x == 0 and ball_y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if self.rng.uniform() > self.zero_boost_weight:
boost = self.rng.uniform(0, 1.000001)
Expand All @@ -585,7 +599,9 @@ def reset(self, state_wrapper: StateWrapper):


class Walldash(StateSetter):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0):
def __init__(self, zero_boost_weight=0, zero_ball_vel_weight=0, ball_vel_mult=1, ball_zero_z=False):
self.ball_zero_z = ball_zero_z
self.ball_vel_mult = ball_vel_mult
self.zero_boost_weight = zero_boost_weight
self.zero_ball_vel_weight = zero_ball_vel_weight
self.rng = np.random.default_rng()
Expand All @@ -602,8 +618,8 @@ def reset(self, state_wrapper: StateWrapper):
state_wrapper.ball.set_lin_vel(0, 0, 0)
else:
state_wrapper.ball.set_lin_vel(0,
self.rng.uniform(-600, 600) if ball_y != 0 else 0,
self.rng.uniform(-200, 200))
self.ball_vel_mult * self.rng.uniform(-600, 600) if ball_y != 0 else 0,
0 if self.zero_ball_vel_weight else self.rng.uniform(-200, 200))
state_wrapper.ball.set_ang_vel(0, 0, 0)
if ball_y >= 0:
ball_sign = 1
Expand Down
16 changes: 16 additions & 0 deletions setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,22 @@ def __init__(self, mode, end_object_choice=None):
(0.1, 0.1, 0.25, 0.25, 0, 0.15, 0.15)
)
)

elif mode == "recovery_ball":
for i in range(3):
self.setters.append(
WeightedSampleSetter(
(HalfFlip(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
Curvedash(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
RandomEvenRecovery(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
Chaindash(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
Walldash(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
Wavedash(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True),
RecoverySetter(zero_boost_weight=0.7, zero_ball_vel_weight=0, ball_zero_z=True)
),
(0, 0.15, 0.2, 0.2, 0.2, 0.1, 0.15)
)
)
# self.setters.append(
# WeightedSampleSetter(
# (
Expand Down
2 changes: 1 addition & 1 deletion worker_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
game_speed=game_speed,
spawn_opponents=True,
team_size=team_size,
state_setter=CoyoteSetter(mode="recovery"),
state_setter=CoyoteSetter(mode="recovery_ball"),
obs_builder=CoyoteObsBuilder(expanding=True,
tick_skip=Constants_recovery.FRAME_SKIP,
team_size=3, extra_boost_info=False,
Expand Down

0 comments on commit 0ff5bef

Please sign in to comment.