Skip to content

Commit

Permalink
demo restart with pretrained on one half of teams
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyotech committed May 8, 2023
1 parent 619baed commit fb9fc39
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
10 changes: 10 additions & 0 deletions Constants_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,13 @@
ZERO_SUM = False
STEP_SIZE = 500_000
DB_NUM = 14

model_name = "nexto-model.pt"
nexto = NextoV2(model_string=model_name, n_players=6)
model_name = "kbb.pt"
kbb = KBB(model_string=model_name)

pretrained_agents = {
nexto: {'prob': 0.5, 'eval': True, 'p_deterministic_training': 1., 'key': "Nexto"},
kbb: {'prob': 0.5, 'eval': True, 'p_deterministic_training': 1., 'key': "KBB"}
}
3 changes: 2 additions & 1 deletion learner_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
stat_trackers=stat_trackers,
gamemodes=("1v1", "2v2", "3v3"),
max_age=1,
pretrained_agents=Constants_demo.pretrained_agents,
)

critic = Sequential(Linear(222, 256), LeakyReLU(), Linear(256, 128), LeakyReLU(),
Expand Down Expand Up @@ -133,7 +134,7 @@
disable_gradient_logging=True,
)

alg.load("Demo_saves/Opti_1683473991.4737124/Opti_5080/checkpoint.pt")
# alg.load("Demo_saves/Opti_1683473991.4737124/Opti_5080/checkpoint.pt")

alg.agent.optimizer.param_groups[0]["lr"] = logger.config.actor_lr
alg.agent.optimizer.param_groups[1]["lr"] = logger.config.critic_lr
Expand Down
39 changes: 38 additions & 1 deletion my_matchmaker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rocket_learn.matchmaker.base_matchmaker import BaseMatchmaker
from trueskill import Rating

from rocket_learn.agent.types import PretrainedAgents
import numpy as np

class MatchmakerWith1v0(BaseMatchmaker):

Expand Down Expand Up @@ -28,3 +29,39 @@ def generate_matchup(self, redis, n_agents, evaluate):
latest_version = -1

return [latest_version] * n_agents, [self.rating] * n_agents, False, n_agents // 2, n_agents // 2


class MatchmakerFullVPretrained(BaseMatchmaker):

def __init__(self, pretrained_agents: PretrainedAgents = None):
if pretrained_agents is not None:
self.consider_pretrained = True
pretrained_agents_keys, pretrained_agents_values = zip(
*pretrained_agents.items())
self.pretrained_agents = pretrained_agents_keys
pretrained_probs = [p["prob"] for p in pretrained_agents_values]
self.pretrained_probs = np.array(
pretrained_probs) / sum(pretrained_probs)
self.pretrained_evals = [p["eval"]
for p in pretrained_agents_values]
self.pretrained_p_deterministic_training = [
p["p_deterministic_training"] if p["p_deterministic_training"] is not None else 1 for p in pretrained_agents_values]
self.pretrained_keys = [p["key"] for p in pretrained_agents_values]
self.pretrained_eval_keys = [k for i, k in enumerate(
self.pretrained_keys) if self.pretrained_evals[i]]
self.rating = Rating(0, 1)

def generate_matchup(self, redis, n_agents, evaluate):

# Doing this instead of int(redis.get(VERSION_LATEST)) because the latest model is
# whatever is currently in rollout worker
latest_version = -1
versions = [latest_version] * (n_agents // 2)

n_each_pretrained = np.random.multinomial(
1, self.pretrained_probs)
pretrained_idx = list(n_each_pretrained).index(1)
pretrained_key = self.pretrained_keys[pretrained_idx]
versions += [pretrained_key] * (n_agents // 2)

return versions, [self.rating] * n_agents, False, n_agents // 2, n_agents // 2
11 changes: 7 additions & 4 deletions worker_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from CoyoteObs import CoyoteObsBuilder
from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
from my_matchmaker import MatchmakerSimple
from my_matchmaker import MatchmakerFullVPretrained
from CoyoteParser import CoyoteAction
from rewards import ZeroSumReward
from torch import set_num_threads
Expand Down Expand Up @@ -43,7 +43,7 @@
0] # [0.825, 0.0826, 0.0578, 0.0346] # this includes past_version and pretrained
deterministic_streamer = True
force_old_deterministic = False
gamemode_weights = {'1v1': 0.4, '2v2': 0.35, '3v3': 0.25}
gamemode_weights = {'1v1': 0.8, '2v2': 0.1, '3v3': 0.1}
visualize = False
simulator = True
batch_mode = True
Expand All @@ -53,7 +53,9 @@
host = "127.0.0.1"
epic_rl_exe_path = None # "D:/Program Files/Epic Games/rocketleague_old/Binaries/Win64/RocketLeague.exe"

matchmaker = MatchmakerSimple()
pretrained_agents = Constants_demo.pretrained_agents

matchmaker = MatchmakerFullVPretrained(pretrained_agents=pretrained_agents)

if len(sys.argv) > 1:
host = sys.argv[1]
Expand Down Expand Up @@ -161,7 +163,8 @@
simulator=simulator,
visualize=visualize,
live_progress=False,
tick_skip=Constants_demo.FRAME_SKIP
tick_skip=Constants_demo.FRAME_SKIP,
pretrained_agents=pretrained_agents,
)

worker.env._match._obs_builder.env = worker.env # noqa
Expand Down

0 comments on commit fb9fc39

Please sign in to comment.