-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
setter changes, zero means, add GP files
- Loading branch information
Showing
11 changed files
with
281 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
FRAME_SKIP = 4 | ||
TIME_HORIZON = 8 # horizon in seconds | ||
T_STEP = FRAME_SKIP / 120 # real time per rollout step | ||
ZERO_SUM = False | ||
ZERO_SUM = True | ||
STEP_SIZE = 1_000_000 | ||
DB_NUM = 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
FRAME_SKIP = 4 | ||
TIME_HORIZON = 4 # horizon in seconds | ||
T_STEP = FRAME_SKIP / 120 # real time per rollout step | ||
ZERO_SUM = False | ||
ZERO_SUM = True | ||
STEP_SIZE = 1_000_000 | ||
DB_NUM = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
FRAME_SKIP = 8 | ||
TIME_HORIZON = 6 # horizon in seconds | ||
T_STEP = FRAME_SKIP / 120 # real time per rollout step | ||
ZERO_SUM = False | ||
STEP_SIZE = 1_000_000 | ||
DB_NUM = 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import wandb | ||
import torch.jit | ||
|
||
from torch.nn import Linear, Sequential, LeakyReLU | ||
|
||
from redis import Redis | ||
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent | ||
from rocket_learn.agent.discrete_policy import DiscretePolicy | ||
from rocket_learn.ppo import PPO | ||
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator | ||
from CoyoteObs import CoyoteObsBuilder | ||
|
||
from CoyoteParser import CoyoteAction | ||
import numpy as np | ||
from rewards import ZeroSumReward | ||
import Constants_gp | ||
|
||
from utils.misc import count_parameters | ||
|
||
import os | ||
from torch import set_num_threads | ||
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \ | ||
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed,\ | ||
MaxGoalSpeed | ||
|
||
# ideas for models: | ||
# get to ball as fast as possible, sometimes with no boost, rewards exist | ||
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity | ||
# half flip, wavedash, wall dash, how to do this one? | ||
# lix reset? | ||
# normal play as well as possible, rewards exist | ||
# aerial play without pinch, rewards exist | ||
# kickoff, 5 second terminal, reward ball distance into opp half | ||
set_num_threads(1) | ||
|
||
if __name__ == "__main__": | ||
frame_skip = Constants_gp.FRAME_SKIP | ||
half_life_seconds = Constants_gp.TIME_HORIZON | ||
fps = 120 / frame_skip | ||
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds)) | ||
config = dict( | ||
actor_lr=1e-4, | ||
critic_lr=1e-4, | ||
n_steps=Constants_gp.STEP_SIZE, | ||
batch_size=100_000, | ||
minibatch_size=50_000, | ||
epochs=30, | ||
gamma=gamma, | ||
save_every=10, | ||
model_every=100, | ||
ent_coef=0.01, | ||
) | ||
|
||
run_id = "gp_run1" | ||
wandb.login(key=os.environ["WANDB_KEY"]) | ||
logger = wandb.init(dir="./wandb_store", | ||
name="GP_Run1", | ||
project="Opti", | ||
entity="kaiyotech", | ||
id=run_id, | ||
config=config, | ||
settings=wandb.Settings(_disable_stats=True, _disable_meta=True), | ||
) | ||
redis = Redis(username="user1", password=os.environ["redis_user1_key"], db=Constants_gp.DB_NUM) # host="192.168.0.201", | ||
redis.delete("worker-ids") | ||
|
||
stat_trackers = [ | ||
Speed(normalize=True), Demos(), TimeoutRate(), Touch(), EpisodeLength(), Boost(), BehindBall(), TouchHeight(), | ||
DistToBall(), AirTouch(), AirTouchHeight(), BallHeight(), BallSpeed(normalize=True), CarOnGround(), | ||
GoalSpeed(), MaxGoalSpeed(), | ||
] | ||
|
||
rollout_gen = RedisRolloutGenerator("Opti_GP", | ||
redis, | ||
lambda: CoyoteObsBuilder(expanding=True, tick_skip=Constants_gp.FRAME_SKIP, | ||
team_size=3, extra_boost_info=False), | ||
lambda: ZeroSumReward(zero_sum=Constants_gp.ZERO_SUM, | ||
goal_w=0, | ||
aerial_goal_w=5, | ||
double_tap_w=10, | ||
flip_reset_w=5, | ||
flip_reset_goal_w=10, | ||
punish_ceiling_pinch_w=-2, | ||
concede_w=-10, | ||
velocity_bg_w=0.25, | ||
acel_ball_w=1, | ||
team_spirit=0, | ||
cons_air_touches_w=2, | ||
jump_touch_w=1, | ||
wall_touch_w=0.5, | ||
), | ||
lambda: CoyoteAction(), | ||
save_every=logger.config.save_every * 3, | ||
model_every=logger.config.model_every, | ||
logger=logger, | ||
clear=False, | ||
stat_trackers=stat_trackers, | ||
# gamemodes=("1v1", "2v2", "3v3"), | ||
max_age=1, | ||
) | ||
|
||
critic = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), | ||
Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), | ||
Linear(512, 1)) | ||
|
||
actor = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), | ||
Linear(512, 373)) | ||
|
||
actor = DiscretePolicy(actor, (373,)) | ||
|
||
optim = torch.optim.Adam([ | ||
{"params": actor.parameters(), "lr": logger.config.actor_lr}, | ||
{"params": critic.parameters(), "lr": logger.config.critic_lr} | ||
]) | ||
|
||
agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim) | ||
print(f"Gamma is: {gamma}") | ||
count_parameters(agent) | ||
|
||
alg = PPO( | ||
rollout_gen, | ||
agent, | ||
ent_coef=logger.config.ent_coef, | ||
n_steps=logger.config.n_steps, | ||
batch_size=logger.config.batch_size, | ||
minibatch_size=logger.config.minibatch_size, | ||
epochs=logger.config.epochs, | ||
gamma=logger.config.gamma, | ||
logger=logger, | ||
zero_grads_with_none=True, | ||
disable_gradient_logging=True, | ||
) | ||
|
||
# alg.load("aerial_saves/Opti_1665434098.5206559/Opti_80/checkpoint.pt") | ||
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr | ||
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr | ||
|
||
alg.run(iterations_per_save=logger.config.save_every, save_dir="GP_saves") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import sys | ||
from redis import Redis | ||
from redis.retry import Retry | ||
from redis.backoff import ExponentialBackoff | ||
from redis.exceptions import ConnectionError, TimeoutError | ||
from rlgym.envs import Match | ||
from CoyoteObs import CoyoteObsBuilder | ||
from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition | ||
from mybots_terminals import BallTouchGroundCondition | ||
from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker | ||
from CoyoteParser import CoyoteAction | ||
from rewards import ZeroSumReward | ||
from torch import set_num_threads | ||
from setter import CoyoteSetter | ||
import Constants_gp | ||
import os | ||
set_num_threads(1) | ||
|
||
|
||
if __name__ == "__main__": | ||
rew = ZeroSumReward(zero_sum=Constants_gp.ZERO_SUM, | ||
goal_w=0, | ||
aerial_goal_w=5, | ||
double_tap_w=10, | ||
flip_reset_w=5, | ||
flip_reset_goal_w=10, | ||
punish_ceiling_pinch_w=-2, | ||
concede_w=-10, | ||
velocity_bg_w=0.25, | ||
acel_ball_w=1, | ||
team_spirit=0, | ||
cons_air_touches_w=2, | ||
jump_touch_w=1, | ||
wall_touch_w=0.5) | ||
frame_skip = Constants_gp.FRAME_SKIP | ||
fps = 120 // frame_skip | ||
name = "Default" | ||
send_gamestate = False | ||
streamer_mode = False | ||
local = True | ||
auto_minimize = True | ||
game_speed = 100 | ||
evaluation_prob = 0 | ||
past_version_prob = 0 | ||
deterministic_streamer = True | ||
force_old_deterministic = False | ||
team_size = 3 | ||
dynamic_game = True | ||
host = "127.0.0.1" | ||
if len(sys.argv) > 1: | ||
host = sys.argv[1] | ||
if host != "127.0.0.1" and host != "localhost": | ||
local = False | ||
if len(sys.argv) > 2: | ||
name = sys.argv[2] | ||
# if len(sys.argv) > 3 and not dynamic_game: | ||
# team_size = int(sys.argv[3]) | ||
if len(sys.argv) > 3: | ||
if sys.argv[3] == 'GAMESTATE': | ||
send_gamestate = True | ||
elif sys.argv[3] == 'STREAMER': | ||
streamer_mode = True | ||
evaluation_prob = 0 | ||
game_speed = 1 | ||
deterministic_streamer = True | ||
auto_minimize = False | ||
|
||
match = Match( | ||
game_speed=game_speed, | ||
spawn_opponents=True, | ||
team_size=team_size, | ||
state_setter=CoyoteSetter(mode="normal"), | ||
obs_builder=CoyoteObsBuilder(expanding=True, tick_skip=Constants_gp.FRAME_SKIP, team_size=team_size, | ||
extra_boost_info=False), | ||
action_parser=CoyoteAction(), | ||
terminal_conditions=[GoalScoredCondition(), | ||
BallTouchGroundCondition(min_time_sec=0, | ||
tick_skip=Constants_gp.FRAME_SKIP, | ||
time_after_ground_sec=1), | ||
], | ||
reward_function=rew, | ||
tick_skip=frame_skip, | ||
) | ||
|
||
# local Redis | ||
if local: | ||
r = Redis(host=host, | ||
username="user1", | ||
password=os.environ["redis_user1_key"], | ||
db=Constants_gp.DB_NUM, | ||
) | ||
|
||
# remote Redis | ||
else: | ||
# noinspection PyArgumentList | ||
r = Redis(host=host, | ||
username="user1", | ||
password=os.environ["redis_user1_key"], | ||
retry_on_error=[ConnectionError, TimeoutError], | ||
retry=Retry(ExponentialBackoff(cap=10, base=1), 25), | ||
db=Constants_gp.DB_NUM, | ||
) | ||
|
||
RedisRolloutWorker(r, name, match, | ||
past_version_prob=past_version_prob, | ||
sigma_target=2, | ||
evaluation_prob=evaluation_prob, | ||
force_paging=False, | ||
dynamic_gm=dynamic_game, | ||
send_obs=True, | ||
auto_minimize=auto_minimize, | ||
send_gamestates=send_gamestate, | ||
gamemode_weights={'1v1': 0.8, '2v2': 0.1, '3v3': 0.1}, # default 1/3 | ||
streamer_mode=streamer_mode, | ||
deterministic_streamer=deterministic_streamer, | ||
force_old_deterministic=force_old_deterministic, | ||
# testing | ||
batch_mode=True, | ||
step_size=Constants_gp.STEP_SIZE, | ||
).run() |