-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from pretrained_agents.nexto.nexto_v2 import NextoV2 | ||
from pretrained_agents.KBB.kbb import KBB | ||
|
||
FRAME_SKIP = 4 | ||
TIME_HORIZON = 6 # horizon in seconds | ||
T_STEP = FRAME_SKIP / 120 # real time per rollout step | ||
ZERO_SUM = False | ||
STEP_SIZE = 500_000 | ||
DB_NUM = 15 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import wandb | ||
import torch.jit | ||
|
||
from torch.nn import Linear, Sequential, LeakyReLU | ||
|
||
from redis import Redis | ||
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent | ||
from rocket_learn.agent.discrete_policy import DiscretePolicy | ||
from rocket_learn.ppo import PPO | ||
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator | ||
from CoyoteObs import CoyoteObsBuilder | ||
|
||
from CoyoteParser import CoyoteAction | ||
import numpy as np | ||
from rewards import ZeroSumReward | ||
import Constants_wall | ||
|
||
from utils.misc import count_parameters | ||
|
||
import os | ||
from torch import set_num_threads | ||
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \ | ||
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed, \ | ||
MaxGoalSpeed | ||
|
||
# ideas for models: | ||
# get to ball as fast as possible, sometimes with no boost, rewards exist | ||
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity | ||
# half flip, wavedash, wall dash, how to do this one? | ||
# lix reset? | ||
# normal play as well as possible, rewards exist | ||
# aerial play without pinch, rewards exist | ||
# kickoff, 5 second terminal, reward ball distance into opp half | ||
set_num_threads(1) | ||
|
||
if __name__ == "__main__": | ||
frame_skip = Constants_wall.FRAME_SKIP | ||
half_life_seconds = Constants_wall.TIME_HORIZON | ||
fps = 120 / frame_skip | ||
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds)) | ||
config = dict( | ||
actor_lr=1e-4, | ||
critic_lr=1e-4, | ||
n_steps=Constants_wall.STEP_SIZE, | ||
batch_size=100_000, | ||
minibatch_size=None, | ||
epochs=30, | ||
gamma=gamma, | ||
save_every=20, | ||
model_every=100, | ||
ent_coef=0.01, | ||
) | ||
|
||
run_id = "wall_run0.00" | ||
wandb.login(key=os.environ["WANDB_KEY"]) | ||
logger = wandb.init(dir="./wandb_store", | ||
name="Wall_Run0.00", | ||
project="Opti", | ||
entity="kaiyotech", | ||
id=run_id, | ||
config=config, | ||
settings=wandb.Settings(_disable_stats=True, _disable_meta=True), | ||
resume=True, | ||
) | ||
redis = Redis(username="user1", password=os.environ["redis_user1_key"], | ||
db=Constants_wall.DB_NUM) # host="192.168.0.201", | ||
redis.delete("worker-ids") | ||
|
||
stat_trackers = [ | ||
TimeoutRate(), Touch(), EpisodeLength(), Boost(), BehindBall(), TouchHeight(), | ||
DistToBall(), AirTouch(), AirTouchHeight(), BallHeight(), BallSpeed(normalize=True), CarOnGround(), | ||
GoalSpeed() | ||
] | ||
|
||
rollout_gen = RedisRolloutGenerator("Opti_Wall", | ||
redis, | ||
lambda: CoyoteObsBuilder(expanding=True, tick_skip=Constants_wall.FRAME_SKIP, | ||
team_size=3, extra_boost_info=False, | ||
embed_players=False, | ||
), | ||
lambda: ZeroSumReward(zero_sum=Constants_wall.ZERO_SUM, | ||
concede_w=-10, | ||
goal_w=10, | ||
velocity_bg_w=0.1, | ||
velocity_pb_w=0.05, | ||
acel_ball_w=2, | ||
jump_touch_w=1, | ||
wall_touch_w=1, | ||
touch_ball_w=0.2, | ||
tick_skip=Constants_wall.FRAME_SKIP, | ||
flatten_wall_height=True, | ||
boost_gain_w=0.75, | ||
boost_spend_w=-0.5, | ||
punish_boost=True, | ||
pun_rew_ball_height_w=0.01, | ||
), | ||
lambda: CoyoteAction(), | ||
save_every=logger.config.save_every * 3, | ||
model_every=logger.config.model_every, | ||
logger=logger, | ||
clear=False, | ||
stat_trackers=stat_trackers, | ||
gamemodes=("1v0", "1v1"), | ||
max_age=1, | ||
) | ||
|
||
critic = Sequential(Linear(226, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(), | ||
Linear(256, 256), LeakyReLU(), | ||
Linear(256, 1)) | ||
|
||
actor = Sequential(Linear(226, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(), | ||
Linear(256, 128), LeakyReLU(), | ||
Linear(128, 373)) | ||
|
||
actor = DiscretePolicy(actor, (373,)) | ||
|
||
optim = torch.optim.Adam([ | ||
{"params": actor.parameters(), "lr": logger.config.actor_lr}, | ||
{"params": critic.parameters(), "lr": logger.config.critic_lr} | ||
]) | ||
|
||
agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim) | ||
print(f"Gamma is: {gamma}") | ||
count_parameters(agent) | ||
|
||
alg = PPO( | ||
rollout_gen, | ||
agent, | ||
ent_coef=logger.config.ent_coef, | ||
n_steps=logger.config.n_steps, | ||
batch_size=logger.config.batch_size, | ||
minibatch_size=logger.config.minibatch_size, | ||
epochs=logger.config.epochs, | ||
gamma=logger.config.gamma, | ||
logger=logger, | ||
zero_grads_with_none=True, | ||
disable_gradient_logging=True, | ||
) | ||
|
||
# alg.load("dtap_saves/Opti_1684953666.264502/Opti_15740/checkpoint.pt") | ||
|
||
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr | ||
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr | ||
|
||
alg.freeze_policy(50) | ||
|
||
alg.run(iterations_per_save=logger.config.save_every, save_dir="Wall_saves") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import sys | ||
from redis import Redis | ||
from redis.retry import Retry # noqa | ||
from redis.backoff import ExponentialBackoff # noqa | ||
from redis.exceptions import ConnectionError, TimeoutError | ||
|
||
from CoyoteObs import CoyoteObsBuilder | ||
from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker | ||
from my_matchmaker import MatchmakerWith1v0 | ||
from CoyoteParser import CoyoteAction | ||
from rewards import ZeroSumReward | ||
from torch import set_num_threads | ||
from setter import CoyoteSetter | ||
from mybots_statesets import EndKickoff | ||
from mybots_terminals import BallTouchGroundCondition | ||
import Constants_wall | ||
import os | ||
|
||
from pretrained_agents.necto.necto_v1 import NectoV1 | ||
from pretrained_agents.nexto.nexto_v2 import NextoV2 | ||
from pretrained_agents.KBB.kbb import KBB | ||
|
||
set_num_threads(1) | ||
|
||
if __name__ == "__main__": | ||
|
||
rew = ZeroSumReward(zero_sum=Constants_wall.ZERO_SUM, | ||
concede_w=-10, | ||
goal_w=10, | ||
velocity_bg_w=0.1, | ||
velocity_pb_w=0.05, | ||
acel_ball_w=2, | ||
jump_touch_w=1, | ||
wall_touch_w=1, | ||
touch_ball_w=0.2, | ||
tick_skip=Constants_wall.FRAME_SKIP, | ||
flatten_wall_height=True, | ||
boost_gain_w=0.75, | ||
boost_spend_w=-0.5, | ||
punish_boost=True, | ||
pun_rew_ball_height_w=0.01, | ||
) | ||
frame_skip = Constants_wall.FRAME_SKIP | ||
fps = 120 // frame_skip | ||
name = "Default" | ||
send_gamestate = False | ||
streamer_mode = False | ||
local = True | ||
auto_minimize = True | ||
game_speed = 100 | ||
evaluation_prob = 0 | ||
past_version_prob = 0 # 0.5 # 0.1 | ||
non_latest_version_prob = [1, 0, 0, | ||
0] # [0.825, 0.0826, 0.0578, 0.0346] # this includes past_version and pretrained | ||
deterministic_streamer = True | ||
force_old_deterministic = False | ||
gamemode_weights = {'1v0': 0.3, '1v1': 0.7} | ||
visualize = False | ||
simulator = True | ||
batch_mode = True | ||
team_size = 3 | ||
dynamic_game = True | ||
infinite_boost_odds = 0.1 | ||
host = "127.0.0.1" | ||
epic_rl_exe_path = None # "D:/Program Files/Epic Games/rocketleague_old/Binaries/Win64/RocketLeague.exe" | ||
|
||
matchmaker = MatchmakerWith1v0() | ||
|
||
if len(sys.argv) > 1: | ||
host = sys.argv[1] | ||
if host != "127.0.0.1" and host != "localhost": | ||
local = False | ||
batch_mode = False | ||
epic_rl_exe_path = None | ||
if len(sys.argv) > 2: | ||
name = sys.argv[2] | ||
# if len(sys.argv) > 3 and not dynamic_game: | ||
# team_size = int(sys.argv[3]) | ||
if len(sys.argv) > 3: | ||
if sys.argv[3] == 'GAMESTATE': | ||
send_gamestate = True | ||
elif sys.argv[3] == 'STREAMER': | ||
streamer_mode = True | ||
evaluation_prob = 0 | ||
game_speed = 1 | ||
auto_minimize = False | ||
infinite_boost_odds = 0.1 | ||
simulator = False | ||
past_version_prob = 0 | ||
|
||
# gamemode_weights = {'1v0': 0.4, '1v1': 0.6} | ||
|
||
elif sys.argv[3] == 'VISUALIZE': | ||
visualize = True | ||
|
||
if simulator: | ||
from rlgym_sim.envs import Match as Sim_Match | ||
from rlgym_sim.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition, \ | ||
NoTouchTimeoutCondition | ||
else: | ||
from rlgym.envs import Match | ||
from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition, \ | ||
NoTouchTimeoutCondition | ||
|
||
setter = CoyoteSetter(mode="normal", simulator=simulator) | ||
|
||
match = Match( | ||
game_speed=game_speed, | ||
spawn_opponents=True, | ||
team_size=team_size, | ||
state_setter=setter, | ||
obs_builder=CoyoteObsBuilder(expanding=True, tick_skip=Constants_wall.FRAME_SKIP, team_size=team_size, | ||
extra_boost_info=False, embed_players=False, | ||
infinite_boost_odds=infinite_boost_odds, | ||
), | ||
action_parser=CoyoteAction(), | ||
terminal_conditions=[GoalScoredCondition(), | ||
TimeoutCondition(fps * 60), | ||
], | ||
reward_function=rew, | ||
tick_skip=frame_skip, | ||
) if not simulator else Sim_Match( | ||
spawn_opponents=True, | ||
team_size=team_size, | ||
state_setter=setter, | ||
obs_builder=CoyoteObsBuilder(expanding=True, tick_skip=Constants_wall.FRAME_SKIP, team_size=team_size, | ||
extra_boost_info=False, embed_players=False, | ||
infinite_boost_odds=infinite_boost_odds, | ||
), | ||
action_parser=CoyoteAction(), | ||
terminal_conditions=[GoalScoredCondition(), | ||
TimeoutCondition(fps * 60), | ||
], | ||
reward_function=rew, | ||
) | ||
|
||
# local Redis | ||
if local: | ||
r = Redis(host=host, | ||
username="user1", | ||
password=os.environ["redis_user1_key"], | ||
db=Constants_wall.DB_NUM, | ||
) | ||
|
||
# remote Redis | ||
else: | ||
# noinspection PyArgumentList | ||
r = Redis(host=host, | ||
username="user1", | ||
password=os.environ["redis_user1_key"], | ||
retry_on_error=[ConnectionError, TimeoutError], | ||
retry=Retry(ExponentialBackoff(cap=10, base=1), 25), | ||
db=Constants_wall.DB_NUM, | ||
) | ||
|
||
worker = RedisRolloutWorker(r, name, match, | ||
matchmaker=matchmaker, | ||
sigma_target=1, | ||
evaluation_prob=evaluation_prob, | ||
force_paging=False, | ||
dynamic_gm=dynamic_game, | ||
send_obs=True, | ||
auto_minimize=auto_minimize, | ||
send_gamestates=send_gamestate, | ||
gamemode_weights=gamemode_weights, # default 1/3 | ||
streamer_mode=streamer_mode, | ||
deterministic_streamer=deterministic_streamer, | ||
force_old_deterministic=force_old_deterministic, | ||
# testing | ||
batch_mode=batch_mode, | ||
step_size=Constants_wall.STEP_SIZE, | ||
# full_team_evaluations=True, | ||
epic_rl_exe_path=epic_rl_exe_path, | ||
simulator=simulator, | ||
visualize=visualize, | ||
live_progress=False, | ||
tick_skip=Constants_wall.FRAME_SKIP | ||
) | ||
|
||
worker.env._match._obs_builder.env = worker.env # noqa | ||
if simulator and visualize: | ||
from rocketsimvisualizer import VisualizerThread | ||
|
||
arena = worker.env._game.arena # noqa | ||
v = VisualizerThread(arena, fps=60, tick_rate=120, tick_skip=frame_skip, step_arena=False, # noqa | ||
overwrite_controls=False) # noqa | ||
v.start() | ||
|
||
worker.run() |