-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
367 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Number of physics frames the environment advances per rollout step.
FRAME_SKIP = 8
# Horizon in seconds. An earlier `TIME_HORIZON = 6` assignment used to sit
# directly above this one; it was dead code because this value overwrote it.
TIME_HORIZON = 2
# Real time per rollout step (frames per step divided by 120 frames per second).
T_STEP = FRAME_SKIP / 120
# Passed through to the reward function's `zero_sum` switch by the trainer.
ZERO_SUM = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import wandb | ||
import torch.jit | ||
|
||
from torch.nn import Linear, Sequential, LeakyReLU | ||
|
||
from redis import Redis | ||
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent | ||
from rocket_learn.agent.discrete_policy import DiscretePolicy | ||
from rocket_learn.ppo import PPO | ||
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator | ||
from CoyoteObs import CoyoteObsBuilder | ||
|
||
from CoyoteParser import CoyoteAction | ||
import numpy as np | ||
from rewards import ZeroSumReward | ||
from Constants_kickoff import FRAME_SKIP, TIME_HORIZON, ZERO_SUM | ||
|
||
from utils.misc import count_parameters | ||
|
||
import os | ||
from torch import set_num_threads | ||
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \ | ||
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed,\ | ||
MaxGoalSpeed | ||
# TODO profile everything before starting to make sure everything is as fast as possible | ||
|
||
# ideas for models: | ||
# get to ball as fast as possible, sometimes with no boost, rewards exist | ||
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity | ||
# half flip, wavedash, wall dash, how to do this one? | ||
# lix reset? | ||
# normal play as well as possible, rewards exist | ||
# aerial play without pinch, rewards exist | ||
# kickoff, 5 second terminal, reward ball distance into opp half | ||
# Restrict torch to a single intra-op thread for this process.
set_num_threads(1)

if __name__ == "__main__":
    # Discount factor derived from a half-life: with `fps` rollout steps per
    # second, gamma ** (fps * half_life_seconds) == 0.5.
    half_life_seconds = TIME_HORIZON
    fps = 120 / FRAME_SKIP
    steps_per_half_life = fps * half_life_seconds
    gamma = np.exp(np.log(0.5) / steps_per_half_life)

    # Hyper-parameters; handed to wandb and read back through `logger.config`.
    config = {
        "actor_lr": 2e-4,
        "critic_lr": 2e-4,
        "n_steps": 100_000,
        "batch_size": 100_000,
        "minibatch_size": 50_000,
        "epochs": 50,
        "gamma": gamma,
        "save_every": 100,
        "model_every": 1000,
        "ent_coef": 0.01,
    }

    run_id = "kickoff_test1"
    wandb.login(key=os.environ["WANDB_KEY"])
    wandb_settings = wandb.Settings(_disable_stats=True, _disable_meta=True)
    logger = wandb.init(
        dir="./wandb_store",
        name="Valger_kickoff",
        project="Valger",
        entity="kaiyotech",
        id=run_id,
        config=config,
        settings=wandb_settings,
    )
    cfg = logger.config

    # Redis connection shared with the rollout workers.
    # (an explicit host="192.168.0.201" was previously passed here)
    redis_conn = Redis(
        username="user1",
        password=os.environ["redis_user1_key"],
        db=1,
    )
    redis_conn.delete("worker-ids")

    stat_trackers = [
        Speed(normalize=True),
        Demos(),
        TimeoutRate(),
        Touch(),
        EpisodeLength(),
        Boost(),
        BehindBall(),
        TouchHeight(),
        DistToBall(),
        AirTouch(),
        AirTouchHeight(),
        BallHeight(),
        BallSpeed(normalize=True),
        CarOnGround(),
        GoalSpeed(),
        MaxGoalSpeed(),
    ]

    # Reward weights for the kickoff task; `zero_sum` is supplied separately
    # from the constants module.
    reward_weights = {
        "goal_w": 10,
        "concede_w": -10,
        "velocity_pb_w": 0.01,
        "boost_gain_w": 1,
        "demo_w": 5,
        "got_demoed_w": -5,
        "kickoff_w": 0.1,
        "ball_opp_half_w": 0.05,
        "team_spirit": 0,
    }

    # The generator receives factories (not instances) for the observation
    # builder, reward function and action parser.
    rollout_generator = RedisRolloutGenerator(
        "Valger_kickoff",
        redis_conn,
        lambda: CoyoteObsBuilder(expanding=True, tick_skip=FRAME_SKIP, team_size=3),
        lambda: ZeroSumReward(zero_sum=ZERO_SUM, **reward_weights),
        lambda: CoyoteAction(),
        save_every=cfg.save_every,
        model_every=cfg.model_every,
        logger=logger,
        clear=True,  # TODO check this
        stat_trackers=stat_trackers,
        # gamemodes=("1v1", "2v2", "3v3"),
        max_age=1,
    )

    obs_size = 247
    hidden = 512
    n_actions = 91

    # Critic: obs -> 6 hidden LeakyReLU layers of width 512 -> scalar value.
    critic = Sequential(
        Linear(obs_size, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, 1),
    )

    # Actor: obs -> 4 hidden LeakyReLU layers of width 512 -> 91 action logits.
    actor_net = Sequential(
        Linear(obs_size, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, hidden), LeakyReLU(),
        Linear(hidden, n_actions),
    )
    actor = DiscretePolicy(actor_net, (n_actions,))

    # One optimizer, two parameter groups: index 0 = actor, index 1 = critic.
    optimizer = torch.optim.Adam([
        {"params": actor.parameters(), "lr": cfg.actor_lr},
        {"params": critic.parameters(), "lr": cfg.critic_lr},
    ])

    agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optimizer)
    print(f"Gamma is: {gamma}")
    count_parameters(agent)

    learner = PPO(
        rollout_generator,
        agent,
        ent_coef=cfg.ent_coef,
        n_steps=cfg.n_steps,
        batch_size=cfg.batch_size,
        minibatch_size=cfg.minibatch_size,
        epochs=cfg.epochs,
        gamma=cfg.gamma,
        logger=logger,
        zero_grads_with_none=True,
        disable_gradient_logging=True,
    )

    # learner.load("model_saves/")
    # Re-apply the configured learning rates to both parameter groups
    # (presumably so a loaded optimizer state cannot override them — confirm).
    actor_group, critic_group = learner.agent.optimizer.param_groups[:2]
    actor_group["lr"] = cfg.actor_lr
    critic_group["lr"] = cfg.critic_lr

    learner.run(iterations_per_save=cfg.save_every, save_dir="kickoff_saves")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
import os | ||
from rlgym.utils.common_values import CEILING_Z, BALL_RADIUS, GOAL_HEIGHT | ||
from rlgym.utils.math import euler_to_rotation, cosine_similarity | ||
|
||
|
||
def parse_aerial(file_name, _num_cars):
    """Curate aerial states: a high ball with at least one car close to it.

    Loads the state array stored in ``file_name`` and keeps every row whose
    ball z coordinate is above ``GOAL_HEIGHT + 100`` and where at least one
    car is closer than ``5 * BALL_RADIUS`` to the ball. The kept rows are
    saved to ``aerial_<file_name>``, replacing any previous file of that name.

    A qualifying row is kept exactly once, no matter how many cars are near
    the ball (previously it was appended once per nearby car, duplicating
    states in 2v2/3v3 data).

    Args:
        file_name: ``.npy`` file to read; also used to build the output name.
        _num_cars: number of cars per row. Everything after the first 9
            values of a row is split into this many equal-sized car chunks.
    """
    data = np.load(file_name)
    min_ball_z = GOAL_HEIGHT + 100
    max_car_dist = 5 * BALL_RADIUS

    output = []
    for row in data:
        ball_pos = row[BALL_POSITION]
        if ball_pos[2] > min_ball_z:
            cars = np.split(row[9:], _num_cars)
            # any() stops at the first nearby car, so the row is added once.
            if any(np.linalg.norm(ball_pos - car[CAR_POS]) < max_car_dist
                   for car in cars):
                output.append(row)

    print(f"Created {len(output)} aerial states from {file_name}")
    output_file = f"aerial_{file_name}"
    if os.path.exists(output_file):
        os.remove(output_file)
    np.save(output_file, output)
|
||
|
||
def parse_flip_resets(file_name, _num_cars):
    """Curate candidate flip-reset states: a high ball just below a car's underside.

    Loads the state array stored in ``file_name`` and keeps every row where

    * the ball's z coordinate is above the midpoint between ``GOAL_HEIGHT``
      and ``CEILING_Z``, and
    * for at least one car, the ball is closer than ``3 * BALL_RADIUS`` and
      the car-to-ball direction has cosine similarity > 0.7 with the negated
      third column of the car's rotation matrix (used here as the car's
      "up" axis, so the ball lies towards the car's underside).

    The kept rows are saved to ``flip_resets_<file_name>``, replacing any
    previous file of that name.

    A qualifying row is kept exactly once, even when several cars satisfy the
    condition (previously it was appended once per matching car).

    Args:
        file_name: ``.npy`` file to read; also used to build the output name.
        _num_cars: number of cars per row. Everything after the first 9
            values of a row is split into this many equal-sized car chunks.
    """
    data = np.load(file_name)
    # Halfway between the top of the goal and the ceiling.
    min_ball_z = CEILING_Z - ((CEILING_Z - GOAL_HEIGHT) / 2)
    max_car_dist = 3 * BALL_RADIUS

    def _ball_under_car(ball_pos, car):
        # True when the ball is close to the car and towards its underside.
        car_up = euler_to_rotation(car[CAR_ROT])[:, 2]
        to_ball = ball_pos - car[CAR_POS]
        return (np.linalg.norm(to_ball) < max_car_dist
                and cosine_similarity(to_ball, -car_up) > 0.7)

    output = []
    for row in data:
        ball_pos = row[BALL_POSITION]
        if ball_pos[2] > min_ball_z:
            cars = np.split(row[9:], _num_cars)
            # any() stops at the first matching car, so the row is added once.
            if any(_ball_under_car(ball_pos, car) for car in cars):
                output.append(row)

    print(f"Created {len(output)} flip reset states from {file_name}")
    output_file = f"flip_resets_{file_name}"
    if os.path.exists(output_file):
        os.remove(output_file)
    np.save(output_file, output)
|
||
|
||
def parse_ceiling_shots(file_name, _num_cars):
    """Curate possible ceiling-shot states: a car upside down near the ceiling.

    Loads the state array stored in ``file_name`` and keeps every row in which
    at least one car has a z coordinate above ``CEILING_Z - 50`` and whose
    negated "up" axis (third column of its rotation matrix) has cosine
    similarity > 0.9 with world up ``[0, 0, 1]``.

    The kept rows are saved to ``ceiling_shots_<file_name>``, replacing any
    previous file of that name. This function used to write to
    ``flip_resets_<file_name>``, which silently overwrote the output of
    ``parse_flip_resets`` for the same input file.

    A qualifying row is kept exactly once, even when several cars are on the
    ceiling (previously it was appended once per matching car).

    Args:
        file_name: ``.npy`` file to read; also used to build the output name.
        _num_cars: number of cars per row. Everything after the first 9
            values of a row is split into this many equal-sized car chunks.
    """
    data = np.load(file_name)
    up = [0, 0, 1]
    min_car_z = CEILING_Z - 50

    def _on_ceiling(car):
        # True when the car's underside faces world-up and it is near the ceiling.
        car_up = euler_to_rotation(car[CAR_ROT])[:, 2]
        car_pos = car[CAR_POS]
        return cosine_similarity(up, -car_up) > 0.9 and car_pos[2] > min_car_z

    output = []
    for state in data:
        cars = np.split(state[9:], _num_cars)
        # any() stops at the first matching car, so the row is added once.
        if any(_on_ceiling(car) for car in cars):
            output.append(state)

    print(f"Created {len(output)} car ceiling states from {file_name}")
    # Dedicated prefix so the flip-reset output is not overwritten.
    output_file = f"ceiling_shots_{file_name}"
    if os.path.exists(output_file):
        os.remove(output_file)
    np.save(output_file, output)
|
||
|
||
# Layout of one saved state row. The first 9 values describe the ball; the
# parsers split the remainder (row[9:]) into per-car chunks, and the CAR_*
# slices index into a single car chunk.
BALL_POSITION = slice(0, 3)
BALL_LIN_VEL = slice(3, 6)
BALL_ANG_VEL = slice(6, 9)

CAR_POS = slice(0, 3)
CAR_ROT = slice(3, 6)
CAR_LIN_VEL = slice(6, 9)
CAR_ANG_VEL = slice(9, 12)
CAR_BOOST = slice(12, 13)

# Input files are ordered 1v1, 2v2, 3v3, so the file at index i holds
# 2 * (i + 1) cars per row.
input_files = [
    'ssl_1v1.npy',
    'ssl_2v2.npy',
    'ssl_3v3.npy',
]
for i, file in enumerate(input_files):
    num_cars = 2 * (i + 1)
    # Run every curation pass over the same input file.
    parse_aerial(file, num_cars)
    parse_flip_resets(file, num_cars)
    parse_ceiling_shots(file, num_cars)
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.