setter changes, zero means, add GP files
Kaiyotech committed Oct 12, 2022
1 parent d1c89b4 commit 01d8176
Showing 11 changed files with 281 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Constants_aerial.py
@@ -1,6 +1,6 @@
FRAME_SKIP = 4
TIME_HORIZON = 8 # horizon in seconds
T_STEP = FRAME_SKIP / 120 # real time per rollout step
- ZERO_SUM = False
+ ZERO_SUM = True
STEP_SIZE = 1_000_000
DB_NUM = 3
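
The ZERO_SUM flag toggled above feeds the ZeroSumReward wrapper (see ZeroSumReward(zero_sum=Constants_gp.ZERO_SUM, ...) in learner_gp.py and worker_gp.py below). rewards.py is not part of this diff, so the following is only a sketch of the common zero-sum / team-spirit formulation such a wrapper typically implements, not the repo's exact code:

import numpy as np

def blend_rewards(raw_rewards, team_of, team_spirit=0.0, zero_sum=True):
    # Illustrative only (assumed formulation, not this repo's rewards.py):
    # mix each player's reward with its team mean, then subtract the
    # opposing team's mean so the two teams' totals roughly cancel.
    raw = np.asarray(raw_rewards, dtype=float)
    teams = np.asarray(team_of)
    out = np.empty_like(raw)
    for i, r in enumerate(raw):
        mates = raw[teams == teams[i]]
        opponents = raw[teams != teams[i]]
        blended = (1 - team_spirit) * r + team_spirit * mates.mean()
        if zero_sum and opponents.size:
            blended -= opponents.mean()
        out[i] = blended
    return out

# blend_rewards([1.0, 0.0, 0.5, 0.0], [0, 0, 1, 1]) -> [0.75, -0.25, 0.0, -0.5]
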
2 changes: 1 addition & 1 deletion Constants_ceil_pinch.py
@@ -1,6 +1,6 @@
FRAME_SKIP = 4
TIME_HORIZON = 4 # horizon in seconds
T_STEP = FRAME_SKIP / 120 # real time per rollout step
- ZERO_SUM = False
+ ZERO_SUM = True
STEP_SIZE = 1_000_000
DB_NUM = 2
6 changes: 6 additions & 0 deletions Constants_gp.py
@@ -0,0 +1,6 @@
FRAME_SKIP = 8
TIME_HORIZON = 6 # horizon in seconds
T_STEP = FRAME_SKIP / 120 # real time per rollout step
ZERO_SUM = False
STEP_SIZE = 1_000_000
DB_NUM = 4
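
A quick note on how these constants are consumed: learner_gp.py (added later in this commit) turns TIME_HORIZON into the PPO discount with gamma = exp(ln 0.5 / (fps * half_life)), i.e. the reward weight halves every TIME_HORIZON seconds of game time. Evaluated with the values above:

import numpy as np

FRAME_SKIP = 8
TIME_HORIZON = 6               # reward half-life in seconds
fps = 120 / FRAME_SKIP         # 15 policy steps per second of game time

# same formula as in learner_gp.py below
gamma = np.exp(np.log(0.5) / (fps * TIME_HORIZON))
print(gamma)                   # 0.5 ** (1 / 90) ~= 0.99233
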
6 changes: 2 additions & 4 deletions learner_aerial.py
@@ -90,7 +90,7 @@
wall_touch_w=0.5,
),
lambda: CoyoteAction(),
- save_every=logger.config.save_every,
+ save_every=logger.config.save_every * 3,
model_every=logger.config.model_every,
logger=logger,
clear=False,
@@ -131,9 +131,7 @@
disable_gradient_logging=True,
)

alg.load("pinch_saves/Opti_1665372734.654689/Opti_7070/checkpoint.pt")
alg.total_steps = 0
alg.starting_iteration = 0
alg.load("aerial_saves/Opti_1665434098.5206559/Opti_80/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr

2 changes: 1 addition & 1 deletion learner_ceil_pinch.py
@@ -90,7 +90,7 @@
touch_height_exp=1.3
),
lambda: CoyoteAction(),
- save_every=logger.config.save_every,
+ save_every=logger.config.save_every * 3,
model_every=logger.config.model_every,
logger=logger,
clear=False,
138 changes: 138 additions & 0 deletions learner_gp.py
@@ -0,0 +1,138 @@
import wandb
import torch.jit

from torch.nn import Linear, Sequential, LeakyReLU

from redis import Redis
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
from rocket_learn.agent.discrete_policy import DiscretePolicy
from rocket_learn.ppo import PPO
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
from CoyoteObs import CoyoteObsBuilder

from CoyoteParser import CoyoteAction
import numpy as np
from rewards import ZeroSumReward
import Constants_gp

from utils.misc import count_parameters

import os
from torch import set_num_threads
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed,\
MaxGoalSpeed

# ideas for models:
# get to ball as fast as possible, sometimes with no boost, rewards exist
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity
# half flip, wavedash, wall dash, how to do this one?
# lix reset?
# normal play as well as possible, rewards exist
# aerial play without pinch, rewards exist
# kickoff, 5 second terminal, reward ball distance into opp half
set_num_threads(1)

if __name__ == "__main__":
frame_skip = Constants_gp.FRAME_SKIP
half_life_seconds = Constants_gp.TIME_HORIZON
fps = 120 / frame_skip
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds))
config = dict(
actor_lr=1e-4,
critic_lr=1e-4,
n_steps=Constants_gp.STEP_SIZE,
batch_size=100_000,
minibatch_size=50_000,
epochs=30,
gamma=gamma,
save_every=10,
model_every=100,
ent_coef=0.01,
)

run_id = "gp_run1"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="GP_Run1",
project="Opti",
entity="kaiyotech",
id=run_id,
config=config,
settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
)
redis = Redis(username="user1", password=os.environ["redis_user1_key"], db=Constants_gp.DB_NUM) # host="192.168.0.201",
redis.delete("worker-ids")

stat_trackers = [
Speed(normalize=True), Demos(), TimeoutRate(), Touch(), EpisodeLength(), Boost(), BehindBall(), TouchHeight(),
DistToBall(), AirTouch(), AirTouchHeight(), BallHeight(), BallSpeed(normalize=True), CarOnGround(),
GoalSpeed(), MaxGoalSpeed(),
]

rollout_gen = RedisRolloutGenerator("Opti_GP",
redis,
lambda: CoyoteObsBuilder(expanding=True, tick_skip=Constants_gp.FRAME_SKIP,
team_size=3, extra_boost_info=False),
lambda: ZeroSumReward(zero_sum=Constants_gp.ZERO_SUM,
goal_w=0,
aerial_goal_w=5,
double_tap_w=10,
flip_reset_w=5,
flip_reset_goal_w=10,
punish_ceiling_pinch_w=-2,
concede_w=-10,
velocity_bg_w=0.25,
acel_ball_w=1,
team_spirit=0,
cons_air_touches_w=2,
jump_touch_w=1,
wall_touch_w=0.5,
),
lambda: CoyoteAction(),
save_every=logger.config.save_every * 3,
model_every=logger.config.model_every,
logger=logger,
clear=False,
stat_trackers=stat_trackers,
# gamemodes=("1v1", "2v2", "3v3"),
max_age=1,
)

critic = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
Linear(512, 1))

actor = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
Linear(512, 373))

actor = DiscretePolicy(actor, (373,))

optim = torch.optim.Adam([
{"params": actor.parameters(), "lr": logger.config.actor_lr},
{"params": critic.parameters(), "lr": logger.config.critic_lr}
])

agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
print(f"Gamma is: {gamma}")
count_parameters(agent)

alg = PPO(
rollout_gen,
agent,
ent_coef=logger.config.ent_coef,
n_steps=logger.config.n_steps,
batch_size=logger.config.batch_size,
minibatch_size=logger.config.minibatch_size,
epochs=logger.config.epochs,
gamma=logger.config.gamma,
logger=logger,
zero_grads_with_none=True,
disable_gradient_logging=True,
)

# alg.load("aerial_saves/Opti_1665434098.5206559/Opti_80/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr

alg.run(iterations_per_save=logger.config.save_every, save_dir="GP_saves")
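
count_parameters(agent) comes from utils.misc, which is not shown in this diff; as a rough sanity check, the two networks defined above can be sized directly with torch. A standalone sketch that just rebuilds the same layer shapes:

import torch
from torch.nn import Linear, Sequential, LeakyReLU

critic = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
                    Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
                    Linear(512, 1))
actor = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
                   Linear(512, 512), LeakyReLU(), Linear(512, 373))

def n_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(n_params(critic))  # 902,657
print(n_params(actor))   # 830,837 (the 373-way output matches the DiscretePolicy shape above)
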
2 changes: 1 addition & 1 deletion mybots_statesets.py
@@ -129,7 +129,7 @@ def reset(self, state_wrapper: StateWrapper):

car_attack.set_pos(*desired_car_pos)
car_attack.set_rot(*desired_rotation)
- car_attack.boost = 100
+ car_attack.boost = rand.uniform(0.3, 1.0)

car_attack.set_lin_vel(0, orange_fix * 200 * x_choice, rng.uniform(1375, 1425))
car_attack.set_ang_vel(0, 0, 0)
18 changes: 9 additions & 9 deletions setter.py
@@ -27,15 +27,15 @@ def __init__(self, mode):
WeightedSampleSetter(
(
DefaultState(),
- AugmentSetter(ReplaySetter(replays[i])),
- AugmentSetter(ReplaySetter(aerial_replays[i])),
+ AugmentSetter(ReplaySetter(replays[i], random_boost=True)),
+ AugmentSetter(ReplaySetter(aerial_replays[i], random_boost=True)),
AugmentSetter(GroundAirDribble(), True, False, False),
AugmentSetter(WallDribble(), True, False, False),
AugmentSetter(RandomState(cars_on_ground=True)),
AugmentSetter(RandomState(cars_on_ground=False)),
),
# (0.05, 0.50, 0.20, 0.20, 0.025, 0.025)
- (0.5, 0.2, 0.1, 0, 0, 0.1, 0.1)
+ (0.3, 0.45, 0.1, 0.025, 0.025, 0.05, 0.05)
)
)
elif mode == "kickoff":
@@ -54,10 +54,10 @@
self.setters.append(
WeightedSampleSetter(
(
- AugmentSetter(ReplaySetter(aerial_replays[i])),
- AugmentSetter(ReplaySetter(flip_reset_replays[i])),
- AugmentSetter(ReplaySetter(ceiling_replays[i])),
- AugmentSetter(ReplaySetter(air_dribble_replays[i])),
+ AugmentSetter(ReplaySetter(aerial_replays[i], random_boost=True)),
+ AugmentSetter(ReplaySetter(flip_reset_replays[i], random_boost=True)),
+ AugmentSetter(ReplaySetter(ceiling_replays[i], random_boost=True)),
+ AugmentSetter(ReplaySetter(air_dribble_replays[i], random_boost=True)),
AugmentSetter(WallDribble(), True, False, False),
AugmentSetter(RandomState(cars_on_ground=False)),
),
@@ -69,8 +69,8 @@
self.setters.append(
WeightedSampleSetter(
(
- AugmentSetter(ReplaySetter(pinch_replays[i])),
- AugmentSetter(ReplaySetter(team_pinch_replays[i])),
+ AugmentSetter(ReplaySetter(pinch_replays[i], random_boost=True)),
+ AugmentSetter(ReplaySetter(team_pinch_replays[i], random_boost=True)),
AugmentSetter(WallDribble(), True, False, False),
),
(0.65, 0.10, 0.25)
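
The tuples at the end of each WeightedSampleSetter above are per-reset sampling weights for the setters listed in the same order (e.g. in the normal-play mode, DefaultState now gets 0.3 while the replay setters get 0.45 and 0.1). Assuming WeightedSampleSetter draws one sub-setter per episode reset according to those weights — an assumption about a class this diff does not show — its behaviour is roughly:

import random

class WeightedSampleSketch:
    # Illustrative stand-in for WeightedSampleSetter (assumed behaviour).
    def __init__(self, setters, weights):
        assert len(setters) == len(weights)
        self.setters = setters
        self.weights = weights

    def reset(self, state_wrapper):
        # pick exactly one setter for this episode, weighted by probability
        chosen = random.choices(self.setters, weights=self.weights, k=1)[0]
        chosen.reset(state_wrapper)
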
2 changes: 1 addition & 1 deletion worker_aerial.py
@@ -105,7 +105,7 @@
past_version_prob=past_version_prob,
sigma_target=2,
evaluation_prob=evaluation_prob,
- force_paging=True,
+ force_paging=False,
dynamic_gm=dynamic_game,
send_obs=True,
auto_minimize=auto_minimize,
2 changes: 1 addition & 1 deletion worker_ceil_pinch.py
@@ -106,7 +106,7 @@
past_version_prob=past_version_prob,
sigma_target=2,
evaluation_prob=evaluation_prob,
- force_paging=True,
+ force_paging=False,
dynamic_gm=dynamic_game,
send_obs=True,
auto_minimize=auto_minimize,
120 changes: 120 additions & 0 deletions worker_gp.py
@@ -0,0 +1,120 @@
import sys
from redis import Redis
from redis.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, TimeoutError
from rlgym.envs import Match
from CoyoteObs import CoyoteObsBuilder
from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition
from mybots_terminals import BallTouchGroundCondition
from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
from CoyoteParser import CoyoteAction
from rewards import ZeroSumReward
from torch import set_num_threads
from setter import CoyoteSetter
import Constants_gp
import os
set_num_threads(1)


if __name__ == "__main__":
rew = ZeroSumReward(zero_sum=Constants_gp.ZERO_SUM,
goal_w=0,
aerial_goal_w=5,
double_tap_w=10,
flip_reset_w=5,
flip_reset_goal_w=10,
punish_ceiling_pinch_w=-2,
concede_w=-10,
velocity_bg_w=0.25,
acel_ball_w=1,
team_spirit=0,
cons_air_touches_w=2,
jump_touch_w=1,
wall_touch_w=0.5)
frame_skip = Constants_gp.FRAME_SKIP
fps = 120 // frame_skip
name = "Default"
send_gamestate = False
streamer_mode = False
local = True
auto_minimize = True
game_speed = 100
evaluation_prob = 0
past_version_prob = 0
deterministic_streamer = True
force_old_deterministic = False
team_size = 3
dynamic_game = True
host = "127.0.0.1"
if len(sys.argv) > 1:
host = sys.argv[1]
if host != "127.0.0.1" and host != "localhost":
local = False
if len(sys.argv) > 2:
name = sys.argv[2]
# if len(sys.argv) > 3 and not dynamic_game:
# team_size = int(sys.argv[3])
if len(sys.argv) > 3:
if sys.argv[3] == 'GAMESTATE':
send_gamestate = True
elif sys.argv[3] == 'STREAMER':
streamer_mode = True
evaluation_prob = 0
game_speed = 1
deterministic_streamer = True
auto_minimize = False

match = Match(
game_speed=game_speed,
spawn_opponents=True,
team_size=team_size,
state_setter=CoyoteSetter(mode="normal"),
obs_builder=CoyoteObsBuilder(expanding=True, tick_skip=Constants_gp.FRAME_SKIP, team_size=team_size,
extra_boost_info=False),
action_parser=CoyoteAction(),
terminal_conditions=[GoalScoredCondition(),
BallTouchGroundCondition(min_time_sec=0,
tick_skip=Constants_gp.FRAME_SKIP,
time_after_ground_sec=1),
],
reward_function=rew,
tick_skip=frame_skip,
)

# local Redis
if local:
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
db=Constants_gp.DB_NUM,
)

# remote Redis
else:
# noinspection PyArgumentList
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
retry_on_error=[ConnectionError, TimeoutError],
retry=Retry(ExponentialBackoff(cap=10, base=1), 25),
db=Constants_gp.DB_NUM,
)

RedisRolloutWorker(r, name, match,
past_version_prob=past_version_prob,
sigma_target=2,
evaluation_prob=evaluation_prob,
force_paging=False,
dynamic_gm=dynamic_game,
send_obs=True,
auto_minimize=auto_minimize,
send_gamestates=send_gamestate,
gamemode_weights={'1v1': 0.8, '2v2': 0.1, '3v3': 0.1}, # default 1/3
streamer_mode=streamer_mode,
deterministic_streamer=deterministic_streamer,
force_old_deterministic=force_old_deterministic,
# testing
batch_mode=True,
step_size=Constants_gp.STEP_SIZE,
).run()
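
One note on the remote-Redis branch above: Retry(ExponentialBackoff(cap=10, base=1), 25) tells redis-py to retry a failed command up to 25 times, and ExponentialBackoff computes the sleep before the n-th retry as roughly min(cap, base * 2**n) seconds (that is the standard redis-py behaviour; worth verifying against the installed version). A quick look at the schedule:

from redis.backoff import ExponentialBackoff

backoff = ExponentialBackoff(cap=10, base=1)
print([backoff.compute(failures) for failures in range(6)])
# expected roughly: [1, 2, 4, 8, 10, 10]  (seconds between retries)
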
