prepping for Lix reset
Kaiyotech committed Feb 9, 2023
1 parent 31d5efa commit 3e79191
Showing 5 changed files with 338 additions and 18 deletions.
6 changes: 6 additions & 0 deletions Constants_lix.py
@@ -0,0 +1,6 @@
FRAME_SKIP = 4
TIME_HORIZON = 2 # horizon in seconds
T_STEP = FRAME_SKIP / 120 # real time per rollout step
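# with FRAME_SKIP = 4 this is 1 / 30 s, i.e. the agent acts at 30 Hz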
ZERO_SUM = False
STEP_SIZE = 500_000
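# shared by the learner (n_steps) and each worker (batch step_size)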
DB_NUM = 12
170 changes: 170 additions & 0 deletions learner_lix.py
@@ -0,0 +1,170 @@
import wandb
import torch.jit

from torch.nn import Linear, Sequential, LeakyReLU

from redis import Redis
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
from rocket_learn.agent.discrete_policy import DiscretePolicy
from rocket_learn.ppo import PPO
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
from CoyoteObs import CoyoteObsBuilder

from CoyoteParser import CoyoteAction
import numpy as np
from rewards import ZeroSumReward
import Constants_lix
from agent import MaskIndices

from utils.misc import count_parameters

import random

import os
from torch import set_num_threads
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed, \
MaxGoalSpeed
from my_stattrackers import GoalSpeedTop5perc

# ideas for models:
# get to ball as fast as possible, sometimes with no boost, rewards exist
# pinches (ceiling and kuxir and team?), score in as few touches as possible with high velocity
# half flip, wavedash, wall dash, how to do this one?
# lix reset?
# normal play as well as possible, rewards exist
# aerial play without pinch, rewards exist
# kickoff, 5 second terminal, reward ball distance into opp half
set_num_threads(1)

if __name__ == "__main__":
frame_skip = Constants_lix.FRAME_SKIP
half_life_seconds = Constants_lix.TIME_HORIZON
fps = 120 / frame_skip
gamma = np.exp(np.log(0.5) / (fps * half_life_seconds))
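    # half-life discounting: gamma = 0.5 ** (1 / (fps * half_life_seconds));
    # with fps = 30 and a 2 s horizon this is 0.5 ** (1 / 60) ≈ 0.9885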
config = dict(
actor_lr=1e-4,
critic_lr=1e-4,
n_steps=Constants_lix.STEP_SIZE,
batch_size=100_000,
minibatch_size=None,
epochs=30,
gamma=gamma,
save_every=10,
model_every=1000,
ent_coef=0.01,
)

run_id = "lix_run0.01"
wandb.login(key=os.environ["WANDB_KEY"])
logger = wandb.init(dir="./wandb_store",
name="Lix_Run0.01",
project="Opti",
entity="kaiyotech",
id=run_id,
config=config,
settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
resume=True,
)
redis = Redis(username="user1", password=os.environ["redis_user1_key"],
db=Constants_lix.DB_NUM) # host="192.168.0.201",
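    # presumably clears stale worker registrations left over from a previous run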
redis.delete("worker-ids")

stat_trackers = [
EpisodeLength(), Boost(),
DistToBall(), CarOnGround(),
]
state = random.getstate()
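    # obs builder, reward, and action parser are passed as factories (lambdas),
    # presumably so fresh copies can be rebuilt when rollouts are deserialized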
rollout_gen = RedisRolloutGenerator("Lix",
redis,
lambda: CoyoteObsBuilder(expanding=True,
tick_skip=Constants_lix.FRAME_SKIP,
team_size=3, extra_boost_info=False,
embed_players=False,
add_jumptime=True,
add_airtime=True,
add_fliptime=True,
add_boosttime=True,
add_handbrake=True),
lambda: ZeroSumReward(zero_sum=Constants_lix.ZERO_SUM,
velocity_pb_w=0.01,
wall_touch_w=0.5,
tick_skip=Constants_lix.FRAME_SKIP,
curve_wave_zap_dash_w=0.35,
walldash_w=0.35,
flip_reset_w=5,
),
lambda: CoyoteAction(),
save_every=logger.config.save_every * 3,
model_every=logger.config.model_every,
logger=logger,
clear=False,
stat_trackers=stat_trackers,
gamemodes=("1v0",),
max_age=1,
)

# critic = Sequential(Linear(47, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
# Linear(256, 128), LeakyReLU(), Linear(128, 128), LeakyReLU(),
# Linear(128, 1))
#
# # mask_array = torch.zeros(222, dtype=torch.bool)
# # mask_array[47:222] = True
# # actor = Sequential(MaskIndices(mask_array), Linear(47, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(), Linear(256, 128), LeakyReLU(),
# # Linear(128, 373))
#
# actor = Sequential(Linear(47, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
# Linear(256, 128), LeakyReLU(), Linear(128, 373))
#
# actor = DiscretePolicy(actor, (373,))

# critic = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
# Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
# Linear(512, 1))
#
# actor = Sequential(Linear(222, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(), Linear(512, 512), LeakyReLU(),
# Linear(512, 373))
#
# actor = DiscretePolicy(actor, (373,))

critic = Sequential(Linear(229, 256), LeakyReLU(), Linear(256, 256), LeakyReLU(),
Linear(256, 128), LeakyReLU(),
Linear(128, 1))

actor = Sequential(Linear(229, 128), LeakyReLU(), Linear(128, 128), LeakyReLU(),
Linear(128, 128), LeakyReLU(),
Linear(128, 373))

actor = DiscretePolicy(actor, (373,))
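    # 229 matches the observation length this CoyoteObsBuilder configuration produces;
    # 373 is assumed to be the size of CoyoteAction's discrete action lookup table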

optim = torch.optim.Adam([
{"params": actor.parameters(), "lr": logger.config.actor_lr},
{"params": critic.parameters(), "lr": logger.config.critic_lr}
])

agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
print(f"Gamma is: {gamma}")
count_parameters(agent)

alg = PPO(
rollout_gen,
agent,
ent_coef=logger.config.ent_coef,
n_steps=logger.config.n_steps,
batch_size=logger.config.batch_size,
minibatch_size=logger.config.minibatch_size,
epochs=logger.config.epochs,
gamma=logger.config.gamma,
logger=logger,
zero_grads_with_none=True,
disable_gradient_logging=True,

)

# alg.load("recovery_saves/Opti_1675569709.6808238/Opti_1630/checkpoint.pt")
alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr
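    # param_groups[0] is the actor and param_groups[1] the critic,
    # matching the order they were passed to Adam above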

# alg.freeze_policy(20)

alg.run(iterations_per_save=logger.config.save_every, save_dir="lix_reset_saves")
26 changes: 26 additions & 0 deletions mybots_statesets.py
@@ -792,6 +792,32 @@ def reset(self, state_wrapper: StateWrapper):
car.boost = boost


class LixSetter(StateSetter):
def __init__(self):
self.rng = np.random.default_rng()

def reset(self, state_wrapper: StateWrapper):
assert len(state_wrapper.cars) == 1
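        # assumed intent: a wall ("lix") flip-reset drill. Spawn the lone car near a
        # side wall, angled 10-35 degrees off it, with the ball just ahead and rolling
        # hard into the wall so it climbs as the car gives chase.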
neg = self.rng.choice([-1, 1])

y = self.rng.uniform(-3000, 2000)
x = neg * (SIDE_WALL_X - self.rng.uniform(800, 1300))
z = 17
car = state_wrapper.cars[0]
car.set_pos(x,
y,
z,
)
car.set_rot(0, ((180 if neg == -1 else 0) + (neg * self.rng.uniform(10, 35))) * DEG_TO_RAD, 0)
speed = self.rng.uniform(400, 1000)
car.set_lin_vel(speed * neg, 0, 0)
car.boost = self.rng.uniform(0.2, 1.000001)

state_wrapper.ball.set_pos(x + (neg * 150), y + 125, 94)
state_wrapper.ball.set_lin_vel(1600 * neg, 400, 0)
state_wrapper.ball.set_ang_vel(0, 0, 0)


def mirror(car: CarWrapper, ball_x, ball_y):
my_car = namedtuple('my_car', 'pos lin_vel rot ang_vel')
if ball_x == ball_y == 0:
30 changes: 12 additions & 18 deletions mybots_terminals.py
@@ -288,31 +288,25 @@ def is_terminal(self, current_state: GameState) -> bool:
         return True
 
 
-class PlayerTouchGround(TerminalCondition):
+class LixTrainer(TerminalCondition):
 
-    def __init__(self, dist_from_side_wall: int = -50, end_object: PhysicsObject = None, allow_boost_full_ground=False):
+    def __init__(self, tick_skip: int, dist_from_side_wall: int = -250, height: int = 1400, time_to_arm_sec: int = 3,
+                 ):
         super().__init__()
-        self.allow_boost_full_ground = allow_boost_full_ground
-        self.end_object = end_object
+        self.height = height
+        self.time_to_arm_steps = time_to_arm_sec * (120 // tick_skip)
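+        # e.g. tick_skip = 4 gives 120 // 4 = 30 steps per second, so 3 s arms after 90 steps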
         self.dist_from_side_wall = dist_from_side_wall
+        self.steps = 0
 
     def reset(self, initial_state: GameState):
-        pass
+        self.steps = 0
 
     def is_terminal(self, current_state: GameState) -> bool:
         """
-        return True if a player touches ground, with hacks for end object allowances
+        return True if a player has exceeded the allowances for a lix reset
         """
-        dist_limit_x = self.dist_from_side_wall
-        if self.end_object is not None and (
-                abs(current_state.ball.position[0]) == 3072 and abs(current_state.ball.position[1]) == 4096):
-            if self.allow_boost_full_ground:
-                return False
-            dist_limit_x = 1300  # allow reaching boost
-        elif self.end_object is not None and \
-                self.end_object.position[0] == self.end_object.position[1] == self.end_object.position[2] == -1:
-            return False
-
+        self.steps += 1
         for i, player in enumerate(current_state.players):
             if player.on_ground and player.car_data.position[2] < 22:
-                return (SIDE_WALL_X - abs(player.car_data.position[0])) > dist_limit_x
+                return (((SIDE_WALL_X - abs(player.car_data.position[0])) > self.dist_from_side_wall) or
+                        player.car_data.position[2] > self.height) and self.steps > self.time_to_arm_steps

124 changes: 124 additions & 0 deletions worker_lix.py
@@ -0,0 +1,124 @@
import sys
from redis import Redis
from redis.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, TimeoutError
from rlgym.envs import Match
from CoyoteObs import CoyoteObsBuilder
from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition, \
BallTouchedCondition
from mybots_terminals import LixTrainer
from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
from CoyoteParser import CoyoteAction
from rewards import ZeroSumReward
from torch import set_num_threads
from mybots_statesets import LixSetter
import Constants_lix
import os

set_num_threads(1)

if __name__ == "__main__":
frame_skip = Constants_lix.FRAME_SKIP
rew = ZeroSumReward(zero_sum=Constants_lix.ZERO_SUM,
velocity_pb_w=0.01,
wall_touch_w=0.5,
tick_skip=Constants_lix.FRAME_SKIP,
curve_wave_zap_dash_w=0.35,
walldash_w=0.35,
flip_reset_w=5,
)

fps = 120 // frame_skip
name = "Default"
send_gamestate = False
streamer_mode = False
local = True
auto_minimize = True
game_speed = 100
evaluation_prob = 0
past_version_prob = 0.1
deterministic_streamer = False
force_old_deterministic = False
gamemode_weights = {'1v1': 1, '2v2': 0, '3v3': 0}
team_size = 1
dynamic_game = False
host = "127.0.0.1"
if len(sys.argv) > 1:
host = sys.argv[1]
if host != "127.0.0.1" and host != "localhost":
local = False
if len(sys.argv) > 2:
name = sys.argv[2]
# if len(sys.argv) > 3 and not dynamic_game:
# team_size = int(sys.argv[3])
if len(sys.argv) > 3:
if sys.argv[3] == 'GAMESTATE':
send_gamestate = True
elif sys.argv[3] == 'STREAMER':
streamer_mode = True
evaluation_prob = 0
game_speed = 1
auto_minimize = False
# gamemode_weights = {'1v1': 1, '2v2': 0, '3v3': 0}

match = Match(
game_speed=game_speed,
spawn_opponents=False,
team_size=team_size,
state_setter=LixSetter(),
obs_builder=CoyoteObsBuilder(expanding=True,
tick_skip=Constants_lix.FRAME_SKIP,
team_size=3, extra_boost_info=False,
embed_players=False,
add_jumptime=True,
add_airtime=True,
add_fliptime=True,
add_boosttime=True,
add_handbrake=True),
action_parser=CoyoteAction("test_setter"),
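        # "test_setter" presumably selects an alternate action table in CoyoteParser;
        # note the learner builds its CoyoteAction() with no argument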
terminal_conditions=[GoalScoredCondition(),
TimeoutCondition(fps * 10),
# TimeoutCondition(fps * 2),
LixTrainer(tick_skip=frame_skip)
],
reward_function=rew,
tick_skip=frame_skip,
)

# local Redis
if local:
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
db=Constants_lix.DB_NUM,
)

# remote Redis
else:
# noinspection PyArgumentList
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
retry_on_error=[ConnectionError, TimeoutError],
retry=Retry(ExponentialBackoff(cap=10, base=1), 25),
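              # retry up to 25 times, with exponential backoff capped at 10 s, on the errors above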
db=Constants_lix.DB_NUM,
)

RedisRolloutWorker(r, name, match,
past_version_prob=past_version_prob,
sigma_target=2,
evaluation_prob=evaluation_prob,
force_paging=False,
dynamic_gm=dynamic_game,
send_obs=True,
auto_minimize=auto_minimize,
send_gamestates=send_gamestate,
# gamemode_weights=gamemode_weights, # default 1/3
streamer_mode=streamer_mode,
deterministic_streamer=deterministic_streamer,
force_old_deterministic=force_old_deterministic,
# testing
batch_mode=True,
step_size=Constants_lix.STEP_SIZE,
).run()
