adding walldash and object tracking; changing recovery temporarily back to non-moving to test episode length before starting walldash
Kaiyotech committed Feb 7, 2023
1 parent a9fc712 commit 440474b
Showing 9 changed files with 569 additions and 65 deletions.
6 changes: 6 additions & 0 deletions Constants_walldash.py
@@ -0,0 +1,6 @@
FRAME_SKIP = 4
TIME_HORIZON = 2 # horizon in seconds
T_STEP = FRAME_SKIP / 120 # real time per rollout step
ZERO_SUM = False
STEP_SIZE = 500_000
DB_NUM = 9
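
Side note (not part of the commit): these constants feed the discount factor computed in learner_walldash.py further down, where gamma is chosen so that rewards decay to half their value after TIME_HORIZON seconds of real time. A minimal sketch of that arithmetic:

import numpy as np

FRAME_SKIP = 4
TIME_HORIZON = 2                                    # reward half-life in seconds
fps = 120 / FRAME_SKIP                              # rollout steps per second (game ticks at 120 Hz)
gamma = np.exp(np.log(0.5) / (fps * TIME_HORIZON))  # fps = 30, so gamma = exp(ln 0.5 / 60) ≈ 0.9885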
23 changes: 16 additions & 7 deletions CoyoteObs.py
@@ -36,9 +36,11 @@ def __init__(self, tick_skip=8, team_size=3, expanding: bool = True, extra_boost
add_fliptime=False,
add_airtime=False,
add_boosttime=False,
dodge_deadzone=0.8
dodge_deadzone=0.8,
end_object: PhysicsObject = None,
):
super().__init__()
self.end_object = end_object
assert add_boosttime == add_airtime == add_fliptime == add_jumptime == add_handbrake, "All timers must match"
self.obs_info = obs_info
# self.obs_output = obs_output
@@ -156,10 +158,6 @@ def reset(self, initial_state: GameState):
self.env.update_settings(boost_consumption=1)
self.infinite_boost_episode = False

if self.end_object_choice is not None and self.end_object_choice == "random":
self.end_object_tracker += 1
if self.end_object_tracker == 7:
self.end_object_tracker = 0

if self.add_boosttime:
self.boosttimes = np.zeros(
@@ -208,6 +206,14 @@ def pre_step(self, state: GameState):
for player in state.players:
player.boost_amount /= 1

if self.end_object is not None and \
not (self.end_object.position[0] == self.end_object.position[1] == self.end_object.position[2] == -1):
state.ball.position[0] = self.end_object.position[0]
state.ball.position[1] = self.end_object.position[1]
state.ball.position[2] = self.end_object.position[2]
state.ball.linear_velocity = np.asarray([0, 0, 0])
state.ball.angular_velocity = np.asarray([0, 0, 0])

def _update_timers(self, state: GameState):
current_boosts = state.boost_pads
boost_locs = self.boost_locations
@@ -820,8 +826,11 @@ def build_obs(self, player: PlayerData, state: GameState, previous_action: np.nd
inverted = False
ball = state.ball

if self.end_object_choice is not None and self.end_object_tracker != 0:
ball.position = self.big_boosts[self.end_object_tracker - 1]
if self.end_object is not None and \
not (self.end_object.position[0] == self.end_object.position[1] == self.end_object.position[2] == -1):
ball.position[0] = self.end_object.position[0]
ball.position[1] = self.end_object.position[1]
ball.position[2] = self.end_object.position[2]
ball.linear_velocity = np.asarray([0, 0, 0])
ball.angular_velocity = np.asarray([0, 0, 0])

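In plain terms, the end_object hook added above lets a state setter point the agent at an arbitrary target: whenever end_object is set and its position is not the (-1, -1, -1) "no target" sentinel, the ball in the observation is snapped to that position with zero velocity. A minimal sketch of the same check factored into helpers (names are illustrative, not from the commit):

import numpy as np

def object_override_active(end_object) -> bool:
    # The commit treats position == (-1, -1, -1) as "no override".
    return end_object is not None and not np.all(end_object.position == -1)

def redirect_ball_to(ball, target_position) -> None:
    # Snap the observed ball onto the target and zero its motion,
    # mirroring what pre_step/build_obs do above.
    ball.position[:3] = target_position[:3]
    ball.linear_velocity = np.zeros(3)
    ball.angular_velocity = np.zeros(3)
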
29 changes: 29 additions & 0 deletions CoyoteParser.py
@@ -104,6 +104,35 @@ def make_lookup_table(version):
actions.append([0, 1, 0, 0, -1, 1, 0, 0])
actions = np.array(actions)

elif version == "test_setter":
# Ground
for throttle in (-1, 0, 0.5, 1):
for steer in (-1, -0.5, 0, 0.5, 1):
for boost in (0, 1):
for handbrake in (0, 1):
if boost == 1 and throttle != 1:
continue
actions.append(
[1, 0, 0, 0, 0, 0, 0, 0])
# Aerial
for pitch in (-0.85, -0.84, -0.83, 0, 0.83, 0.84, 0.85):
for yaw in (0, 0, 0, 0, 0, 0, 0):
for roll in (0, 0, 0):
for jump in (0, 1):
for boost in (0, 1):
if jump == 1 and yaw != 0: # Only need roll for sideflip
continue
if pitch == roll == jump == 0: # Duplicate with ground
continue
# Enable handbrake for potential wavedashes
handbrake = jump == 1 and (
pitch != 0 or yaw != 0 or roll != 0)
actions.append(
[1, 0, 0, 0, 0, 0, 0, 0])
# append stall
actions.append([1, 0, 0, 0, 0, 0, 0, 0])
actions = np.array(actions)

return actions

def get_action_space(self) -> gym.spaces.Space:
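For reference, a lookup-table parser of this kind maps each discrete index the policy outputs to one 8-element controller row (throttle, steer, pitch, yaw, roll, jump, boost, handbrake), which is why the actor below ends in 373 logits. A rough sketch of the lookup step, assuming CoyoteAction behaves like other rocket_learn lookup parsers (parse_actions here is illustrative, not copied from the repo):

import numpy as np

class LookupAction:
    def __init__(self, table: np.ndarray):
        self._table = table  # shape (n_actions, 8); 373 rows for this model

    def parse_actions(self, actions: np.ndarray, state=None) -> np.ndarray:
        # One index per agent in, one controller row per agent out.
        indices = np.asarray(actions, dtype=np.int64).reshape(-1)
        return self._table[indices]
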
151 changes: 151 additions & 0 deletions learner_walldash.py
@@ -0,0 +1,151 @@
import wandb
import torch.jit

from torch.nn import Linear, Sequential, LeakyReLU

from redis import Redis
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
from rocket_learn.agent.discrete_policy import DiscretePolicy
from rocket_learn.ppo import PPO
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
from CoyoteObs import CoyoteObsBuilder

from CoyoteParser import CoyoteAction
import numpy as np
from rewards import ZeroSumReward
import Constants_walldash
from agent import MaskIndices

from utils.misc import count_parameters

import random

import os
from torch import set_num_threads
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed, \
MaxGoalSpeed
from my_stattrackers import GoalSpeedTop5perc

# ideas for models:
# get to ball as fast as possible, sometimes with no boost, rewards exist
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity
# half flip, wavedash, wall dash, how to do this one?
# lix reset?
# normal play as well as possible, rewards exist
# aerial play without pinch, rewards exist
# kickoff, 5 second terminal, reward ball distance into opp half
set_num_threads(1)

if __name__ == "__main__":
frame_skip = Constants_walldash.FRAME_SKIP
half_life_seconds = Constants_walldash.TIME_HORIZON
fps = 120 / frame_skip
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds))
config = dict(
actor_lr=1e-4,
critic_lr=1e-4,
n_steps=Constants_walldash.STEP_SIZE,
batch_size=100_000,
minibatch_size=100_000,
epochs=30,
gamma=gamma,
save_every=10,
model_every=1000,
ent_coef=0.01,
)

run_id = "walldash_run1.00"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="Walldash_Run1.00",
project="Opti",
entity="kaiyotech",
id=run_id,
config=config,
settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
resume=True,
)
redis = Redis(username="user1", password=os.environ["redis_user1_key"],
db=Constants_walldash.DB_NUM) # host="192.168.0.201",
redis.delete("worker-ids")

stat_trackers = [
Speed(normalize=True), Touch(), EpisodeLength(), Boost(),
DistToBall(), CarOnGround(),
]
state = random.getstate()
rollout_gen = RedisRolloutGenerator("Walldash",
redis,
lambda: CoyoteObsBuilder(expanding=True,
tick_skip=Constants_walldash.FRAME_SKIP,
team_size=1, extra_boost_info=False,
embed_players=False,
add_jumptime=True,
add_airtime=True,
add_fliptime=True,
add_boosttime=True,
add_handbrake=True),
lambda: ZeroSumReward(zero_sum=Constants_walldash.ZERO_SUM,
velocity_pb_w=0.02,
boost_gain_w=0.35,
boost_spend_w=4,
punish_boost=True,
touch_ball_w=2.5,
boost_remain_touch_w=2,
final_reward_ball_dist_w=1,
final_reward_boost_w=0.3,
tick_skip=Constants_walldash.FRAME_SKIP,
walldash_w=0.35,
),
lambda: CoyoteAction(),
save_every=logger.config.save_every * 3,
model_every=logger.config.model_every,
logger=logger,
clear=False,
stat_trackers=stat_trackers,
gamemodes=("1v0",),
max_age=1,
)

critic = Sequential(Linear(229, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
Linear(256, 128), LeakyReLU(),
Linear(128, 1))

actor = Sequential(Linear(229, 128), LeakyReLU(), Linear(128, 128), LeakyReLU(),
Linear(128, 128), LeakyReLU(),
Linear(128, 373))

actor = DiscretePolicy(actor, (373,))

optim = torch.optim.Adam([
{"params": actor.parameters(), "lr": logger.config.actor_lr},
{"params": critic.parameters(), "lr": logger.config.critic_lr}
])

agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
print(f"Gamma is: {gamma}")
count_parameters(agent)

alg = PPO(
rollout_gen,
agent,
ent_coef=logger.config.ent_coef,
n_steps=logger.config.n_steps,
batch_size=logger.config.batch_size,
minibatch_size=logger.config.minibatch_size,
epochs=logger.config.epochs,
gamma=logger.config.gamma,
logger=logger,
zero_grads_with_none=True,
disable_gradient_logging=True,

)

alg.load("recovery_saves/Opti_1675569709.6808238/Opti_1630/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr

# alg.freeze_policy(20)

alg.run(iterations_per_save=logger.config.save_every, save_dir="recovery_ball_saves")