From 95c7a448be6f86920be5e866f70a6f02ce573249 Mon Sep 17 00:00:00 2001 From: Kaiyotech <93724202+Kaiyotech@users.noreply.github.com> Date: Thu, 29 Jun 2023 11:20:00 -0400 Subject: [PATCH] beginning of action group graphing but not finished --- learner_selector.py | 10 +++++++--- my_stattrackers.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/learner_selector.py b/learner_selector.py index 4eecaf0..5ba7914 100644 --- a/learner_selector.py +++ b/learner_selector.py @@ -23,7 +23,7 @@ from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \ BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed, \ MaxGoalSpeed -from my_stattrackers import GoalSpeedTop5perc, FlipReset +from my_stattrackers import GoalSpeedTop5perc, FlipReset, ActionGroupingTracker from rlgym.utils.reward_functions.common_rewards import VelocityReward, EventReward from rlgym.utils.reward_functions.combined_reward import CombinedReward @@ -55,10 +55,10 @@ ent_coef=0.03, ) - run_id = "selector_run_21.01" + run_id = "selector_run_test21.01" wandb.login(key=os.environ["WANDB_KEY"]) logger = wandb.init(dir="./wandb_store", - name="Selector_Run_21.01", + name="Selector_Run_test21.01", project="Opti", entity="kaiyotech", id=run_id, @@ -147,6 +147,10 @@ logger=logger, clear=False, stat_trackers=stat_trackers, + # action_grouping_tracker=ActionGroupingTracker(aerial_indices=[3, 6, 7, 8, 28, 29], + # wall_indices=[8, 25, 26, 28, 29], + # ground_indices=[0, 1, 2, 4, 5, *range(9, 25), 27, 29], + # defend_indices=[3, 6, 7, 8, 28]), # gamemodes=("1v1", "2v2", "3v3"), max_age=1, pretrained_agents=Constants_selector.pretrained_agents diff --git a/my_stattrackers.py b/my_stattrackers.py index dd6c2bf..20f8c97 100644 --- a/my_stattrackers.py +++ b/my_stattrackers.py @@ -97,3 +97,42 @@ def update(self, gamestates: np.ndarray, mask: np.ndarray): 
def get_stat(self): return self.flip_reset_count / (self.count or 1) + + +class ActionGroupingTracker: + # check if actions were used at the appropriate time + + def __init__(self, aerial_indices, ground_indices, defend_indices, wall_indices): + self.name = "Action Grouping Tracker" + self.count = 0 + self.flip_reset_count = 0 + + def reset(self): + self.count = 0 + self.flip_reset_count = 0 + + def update(self, gamestates: np.ndarray, mask: np.ndarray): + players = gamestates[:, StateConstants.PLAYERS] + num_players = len(players[0]) // 39 + has_jumps = players[:, StateConstants.HAS_JUMP] + # on_grounds = players[:, StateConstants.ON_GROUND] + players_x = players[:, StateConstants.CAR_POS_X] + players_y = players[:, StateConstants.CAR_POS_Y] + players_z = players[:, StateConstants.CAR_POS_Z] + for i in range(num_players): + has_jumps_player = has_jumps[:, i] + changes = np.where(has_jumps_player[1:] < has_jumps_player[:-1], True, False) + player_x = players_x[:, i] + player_y = players_y[:, i] + player_z = players_z[:, i] + on_grounds_player = np.where((player_z < 300) | (player_z > CEILING_Z - 300) | + ((-SIDE_WALL_X + 700) > player_x) | + ((SIDE_WALL_X - 700) < player_x) | + ((-BACK_WALL_Y + 700) > player_y) | + ((BACK_WALL_Y - 700) < player_y), True, False) + self.flip_reset_count += (~on_grounds_player[1:] & changes).sum() + + self.count += has_jumps.size + + def get_stat(self): + return self.flip_reset_count / (self.count or 1) \ No newline at end of file