Skip to content

Commit

Permalink
Send selector actions to Redis from Streamer
Browse files Browse the repository at this point in the history
  • Loading branch information
nevercast committed Jan 8, 2023
1 parent db92196 commit 975e024
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 52 deletions.
27 changes: 27 additions & 0 deletions Constants_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,30 @@
STEP_SIZE = 400_000
DB_NUM = 7
STACK_SIZE = 20
SELECTION_CHANNEL = "on_model_selection"

SUB_MODEL_NAMES = [
"kickoff_1",
"kickoff_2",
"GP",
"aerial",
"flick_bump",
"flick",
"flip_reset_1",
"flip_reset_2",
"flip_reset_3",
"pinch",
"recover_0",
"recover_-45",
"recover_-90",
"recover_-135",
"recover_180",
"recover_135",
"recover_90",
"recover_45",
"recover_back_left",
"recover_back_right",
"recover_opponent",
"recover_back_post",
"recover_ball",
]
13 changes: 12 additions & 1 deletion CoyoteParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rlgym.utils.action_parsers import ActionParser
from rlgym.utils.gamestates import GameState
from CoyoteObs import CoyoteObsBuilder
from selection_listener import SelectionListener


class CoyoteAction(ActionParser):
Expand Down Expand Up @@ -294,10 +295,15 @@ def override_abs_state(player, state, position_index) -> GameState:
retstate.players[i].inverted_car_data.position = oppo_pos
return retstate


class SelectorParser(ActionParser):
def __init__(self):
from submodels.submodel_agent import SubAgent
from Constants_selector import SUB_MODEL_NAMES
self.sub_model_names = [
name.replace("recover", "rush").replace("_", " ").title()
for name in SUB_MODEL_NAMES
]
self.selection_listener = None
super().__init__()

self.models = [(SubAgent("kickoff_1_jit.pt"), CoyoteObsBuilder(expanding=True, tick_skip=4, team_size=3)),
Expand Down Expand Up @@ -404,6 +410,8 @@ def parse_actions(self, actions: Any, state: GameState) -> np.ndarray:
obs = self.models[action][1].build_obs(
player, newstate, self.prev_actions[i])
parse_action = self.models[action][0].act(obs)[0]
if self.selection_listener is not None:
self.selection_listener.on_selection(self.sub_model_names[action], parse_action)
# self.prev_action[i] = np.asarray(parse_action)
self.prev_actions[i] = parse_action
parsed_actions.append(parse_action)
Expand All @@ -415,6 +423,9 @@ def reset(self, initial_state: GameState):
for model in self.models:
model[1].reset(initial_state)

def register_selection_listener(self, listener: SelectionListener):
self.selection_listener = listener


if __name__ == '__main__':
ap = CoyoteAction()
Expand Down
26 changes: 1 addition & 25 deletions learner_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,31 +139,7 @@
print(f"Gamma is: {gamma}")
count_parameters(agent)

action_dict = {0: "kickoff_1",
1: "kickoff_2",
2: "GP",
3: "aerial",
4: "flick_bump",
5: "flick",
6: "flip_reset_1",
7: "flip_reset_2",
8: "flip_reset_3",
9: "pinch",
10: "recover_0",
11: "recover_-45",
12: "recover_-90",
13: "recover_-135",
14: "recover_180",
15: "recover_135",
16: "recover_90",
17: "recover_45",
18: "recover_back_left",
19: "recover_back_right",
20: "recover_opponent",
21: "recover_back_post",
22: "recover_ball",
}

action_dict = { i: k for i,k in enumerate(Constants_selector.SUB_MODEL_NAMES) }
alg = PPO(
rollout_gen,
agent,
Expand Down
7 changes: 7 additions & 0 deletions selection_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
import numpy as np

class SelectionListener(ABC):
@abstractmethod
def on_selection(self, selected_model_name: str, model_action: np.ndarray):
pass
72 changes: 46 additions & 26 deletions worker_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,25 @@
from CoyoteParser import SelectorParser
from rewards import ZeroSumReward
from torch import set_num_threads
from selection_listener import SelectionListener
from setter import CoyoteSetter
import Constants_selector
import numpy as np
import json
import os

set_num_threads(1)

class SelectionDispatcher(SelectionListener):
def __init__(self, redis, redis_channel) -> None:
super().__init__()
self.redis = redis
self.redis_channel = redis_channel

def on_selection(self, selected_model_name: str, model_action: np.ndarray):
selection_message = dict(model=selected_model_name, actions=model_action.tolist())
self.redis.publish(self.redis_channel, json.dumps(selection_message))

if __name__ == "__main__":
rew = ZeroSumReward(zero_sum=Constants_selector.ZERO_SUM,
goal_w=5,
Expand Down Expand Up @@ -53,23 +66,49 @@
dynamic_game = True
infinite_boost_odds = 0
host = "127.0.0.1"

if len(sys.argv) > 1:
host = sys.argv[1]
if host != "127.0.0.1" and host != "localhost":
local = False

if len(sys.argv) > 2:
name = sys.argv[2]
# if len(sys.argv) > 3 and not dynamic_game:
# team_size = int(sys.argv[3])

# local Redis
if local:
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
db=Constants_selector.DB_NUM,
)

# remote Redis
else:
# noinspection PyArgumentList
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
retry_on_error=[ConnectionError, TimeoutError],
retry=Retry(ExponentialBackoff(cap=10, base=1), 25),
db=Constants_selector.DB_NUM,
)

def setup_streamer():
global game_speed, evaluation_prob, past_version_prob, auto_minimize, infinite_boost_odds, streamer_mode
streamer_mode = True
evaluation_prob = 0
game_speed = 1
auto_minimize = False
infinite_boost_odds = 0
dispatcher = SelectionDispatcher(r, Constants_selector.SELECTION_CHANNEL)
parser.register_selection_listener(dispatcher)

if len(sys.argv) > 3:
if sys.argv[3] == 'GAMESTATE':
send_gamestate = True
elif sys.argv[3] == 'STREAMER':
streamer_mode = True
evaluation_prob = 0
game_speed = 1
auto_minimize = False
infinite_boost_odds = 0
setup_streamer()

match = Match(
game_speed=game_speed,
Expand All @@ -89,25 +128,6 @@
tick_skip=frame_skip,
)

# local Redis
if local:
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
db=Constants_selector.DB_NUM,
)

# remote Redis
else:
# noinspection PyArgumentList
r = Redis(host=host,
username="user1",
password=os.environ["redis_user1_key"],
retry_on_error=[ConnectionError, TimeoutError],
retry=Retry(ExponentialBackoff(cap=10, base=1), 25),
db=Constants_selector.DB_NUM,
)

worker = RedisRolloutWorker(r, name, match,
past_version_prob=past_version_prob,
sigma_target=2,
Expand Down

0 comments on commit 975e024

Please sign in to comment.