From 975e02413e7e8d684e4828cee86dc6ad2deb63ff Mon Sep 17 00:00:00 2001 From: Josh Lloyd <834342+nevercast@users.noreply.github.com> Date: Sun, 8 Jan 2023 20:17:01 +1300 Subject: [PATCH] Send selector actions to Redis from Streamer --- Constants_selector.py | 27 ++++++++++++++++ CoyoteParser.py | 13 +++++++- learner_selector.py | 26 +--------------- selection_listener.py | 7 +++++ worker_selector.py | 72 +++++++++++++++++++++++++++---------------- 5 files changed, 93 insertions(+), 52 deletions(-) create mode 100644 selection_listener.py diff --git a/Constants_selector.py b/Constants_selector.py index dbb1786..0d950d7 100644 --- a/Constants_selector.py +++ b/Constants_selector.py @@ -5,3 +5,30 @@ STEP_SIZE = 400_000 DB_NUM = 7 STACK_SIZE = 20 +SELECTION_CHANNEL = "on_model_selection" + +SUB_MODEL_NAMES = [ + "kickoff_1", + "kickoff_2", + "GP", + "aerial", + "flick_bump", + "flick", + "flip_reset_1", + "flip_reset_2", + "flip_reset_3", + "pinch", + "recover_0", + "recover_-45", + "recover_-90", + "recover_-135", + "recover_180", + "recover_135", + "recover_90", + "recover_45", + "recover_back_left", + "recover_back_right", + "recover_opponent", + "recover_back_post", + "recover_ball", +] diff --git a/CoyoteParser.py b/CoyoteParser.py index 39258e7..8a2a9ea 100644 --- a/CoyoteParser.py +++ b/CoyoteParser.py @@ -9,6 +9,7 @@ from rlgym.utils.action_parsers import ActionParser from rlgym.utils.gamestates import GameState from CoyoteObs import CoyoteObsBuilder +from selection_listener import SelectionListener class CoyoteAction(ActionParser): @@ -294,10 +295,15 @@ def override_abs_state(player, state, position_index) -> GameState: retstate.players[i].inverted_car_data.position = oppo_pos return retstate - class SelectorParser(ActionParser): def __init__(self): from submodels.submodel_agent import SubAgent + from Constants_selector import SUB_MODEL_NAMES + self.sub_model_names = [ + name.replace("recover", "rush").replace("_", " ").title() + for name in SUB_MODEL_NAMES + ] + self.selection_listener = None super().__init__() self.models = [(SubAgent("kickoff_1_jit.pt"), CoyoteObsBuilder(expanding=True, tick_skip=4, team_size=3)), @@ -404,6 +410,8 @@ def parse_actions(self, actions: Any, state: GameState) -> np.ndarray: obs = self.models[action][1].build_obs( player, newstate, self.prev_actions[i]) parse_action = self.models[action][0].act(obs)[0] + if self.selection_listener is not None: + self.selection_listener.on_selection(self.sub_model_names[action], parse_action) # self.prev_action[i] = np.asarray(parse_action) self.prev_actions[i] = parse_action parsed_actions.append(parse_action) @@ -415,6 +423,9 @@ def reset(self, initial_state: GameState): for model in self.models: model[1].reset(initial_state) + def register_selection_listener(self, listener: SelectionListener): + self.selection_listener = listener + if __name__ == '__main__': ap = CoyoteAction() diff --git a/learner_selector.py b/learner_selector.py index 3ce30bc..674b130 100644 --- a/learner_selector.py +++ b/learner_selector.py @@ -139,31 +139,7 @@ print(f"Gamma is: {gamma}") count_parameters(agent) - action_dict = {0: "kickoff_1", - 1: "kickoff_2", - 2: "GP", - 3: "aerial", - 4: "flick_bump", - 5: "flick", - 6: "flip_reset_1", - 7: "flip_reset_2", - 8: "flip_reset_3", - 9: "pinch", - 10: "recover_0", - 11: "recover_-45", - 12: "recover_-90", - 13: "recover_-135", - 14: "recover_180", - 15: "recover_135", - 16: "recover_90", - 17: "recover_45", - 18: "recover_back_left", - 19: "recover_back_right", - 20: "recover_opponent", - 21: "recover_back_post", - 22: "recover_ball", - } - + action_dict = { i: k for i,k in enumerate(Constants_selector.SUB_MODEL_NAMES) } alg = PPO( rollout_gen, agent, diff --git a/selection_listener.py b/selection_listener.py new file mode 100644 index 0000000..07730d7 --- /dev/null +++ b/selection_listener.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod +import numpy as np + +class SelectionListener(ABC): + @abstractmethod + def on_selection(self, selected_model_name: str, model_action: np.ndarray): + pass diff --git a/worker_selector.py b/worker_selector.py index 7838812..49ed16f 100644 --- a/worker_selector.py +++ b/worker_selector.py @@ -11,12 +11,25 @@ from CoyoteParser import SelectorParser from rewards import ZeroSumReward from torch import set_num_threads +from selection_listener import SelectionListener from setter import CoyoteSetter import Constants_selector +import numpy as np +import json import os set_num_threads(1) +class SelectionDispatcher(SelectionListener): + def __init__(self, redis, redis_channel) -> None: + super().__init__() + self.redis = redis + self.redis_channel = redis_channel + + def on_selection(self, selected_model_name: str, model_action: np.ndarray): + selection_message = dict(model=selected_model_name, actions=model_action.tolist()) + self.redis.publish(self.redis_channel, json.dumps(selection_message)) + if __name__ == "__main__": rew = ZeroSumReward(zero_sum=Constants_selector.ZERO_SUM, goal_w=5, @@ -53,23 +66,49 @@ dynamic_game = True infinite_boost_odds = 0 host = "127.0.0.1" + if len(sys.argv) > 1: host = sys.argv[1] if host != "127.0.0.1" and host != "localhost": local = False + if len(sys.argv) > 2: name = sys.argv[2] - # if len(sys.argv) > 3 and not dynamic_game: - # team_size = int(sys.argv[3]) + + # local Redis + if local: + r = Redis(host=host, + username="user1", + password=os.environ["redis_user1_key"], + db=Constants_selector.DB_NUM, + ) + + # remote Redis + else: + # noinspection PyArgumentList + r = Redis(host=host, + username="user1", + password=os.environ["redis_user1_key"], + retry_on_error=[ConnectionError, TimeoutError], + retry=Retry(ExponentialBackoff(cap=10, base=1), 25), + db=Constants_selector.DB_NUM, + ) + + def setup_streamer(): + global game_speed, evaluation_prob, past_version_prob, auto_minimize, infinite_boost_odds, streamer_mode + streamer_mode = True + evaluation_prob = 0 + game_speed = 1 + auto_minimize = False + infinite_boost_odds = 0 + dispatcher = SelectionDispatcher(r, Constants_selector.SELECTION_CHANNEL) + parser.register_selection_listener(dispatcher) + if len(sys.argv) > 3: if sys.argv[3] == 'GAMESTATE': send_gamestate = True elif sys.argv[3] == 'STREAMER': - streamer_mode = True - evaluation_prob = 0 - game_speed = 1 - auto_minimize = False - infinite_boost_odds = 0 + setup_streamer() match = Match( game_speed=game_speed, @@ -89,25 +128,6 @@ tick_skip=frame_skip, ) - # local Redis - if local: - r = Redis(host=host, - username="user1", - password=os.environ["redis_user1_key"], - db=Constants_selector.DB_NUM, - ) - - # remote Redis - else: - # noinspection PyArgumentList - r = Redis(host=host, - username="user1", - password=os.environ["redis_user1_key"], - retry_on_error=[ConnectionError, TimeoutError], - retry=Retry(ExponentialBackoff(cap=10, base=1), 25), - db=Constants_selector.DB_NUM, - ) - worker = RedisRolloutWorker(r, name, match, past_version_prob=past_version_prob, sigma_target=2,