Skip to content

Commit

Permalink
need big simplifications...
Browse files Browse the repository at this point in the history
- state: hand needs to be reduced to binary per numeral/face class
- actions need to express intent for movement rather than exact movements - more "coaching"
  • Loading branch information
r3w0p committed Jan 6, 2025
1 parent 28343bb commit d9239d3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 43 deletions.
2 changes: 1 addition & 1 deletion include/caravan/core/training.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

const uint16_t SIZE_ACTION = 5;
const uint16_t SIZE_ACTION_SPACE = 920;
const uint16_t SIZE_GAME_STATE = 200;
const uint16_t SIZE_GAME_STATE = 38;

const uint8_t NUM_PLAYER_ABC = 1;
const uint8_t NUM_PLAYER_DEF = 2;
Expand Down
57 changes: 23 additions & 34 deletions src/caravan/core/training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ uint8_t suit_to_uint8_t(Suit s) {
return static_cast<uint8_t>(s);
}

// Encode a Direction as a raw byte by casting the enumerator to uint8_t.
// Per the original note on this function, ANY maps to 0.
uint8_t direction_to_uint8_t(Direction d) {
    const auto encoded = static_cast<uint8_t>(d);
    return encoded;
}

void add_hand_to_game_state(GameState *gs, uint16_t *i_gs, Player *player) {
Hand hand = player->get_hand();
uint8_t hand_size = player->get_size_hand();
Expand All @@ -62,36 +67,28 @@ void add_hand_to_game_state(GameState *gs, uint16_t *i_gs, Player *player) {
}
}

void add_caravan_to_game_state(GameState *gs, uint16_t *i_gs, Caravan *caravan) {
// Current highest numeral position along caravan track
uint8_t max_track = caravan->get_size();
void add_caravan_to_game_state(GameState *gs, uint16_t *i_gs, Game *game, Caravan *caravan) {
uint8_t caravan_size = caravan->get_size();

// Add whether caravan is winning
(*gs)[(*i_gs)++] = game->is_caravan_winning(caravan->get_name());

// Add whether numeral track is full
(*gs)[(*i_gs)++] = caravan_size == TRACK_NUMERIC_MAX;

// Add caravan direction
(*gs)[(*i_gs)++] = direction_to_uint8_t(caravan->get_direction());

// Add caravan suit
(*gs)[(*i_gs)++] = suit_to_uint8_t(caravan->get_suit());

for (uint8_t i_track = 0; i_track < TRACK_NUMERIC_MAX; i_track++) {
// If numeral at track position, fetch slot state
if (i_track < max_track) {
Slot slot = caravan->get_slot(i_track + 1);

// Add numeral
(*gs)[(*i_gs)++] = rank_to_uint8_t(slot.card.rank);
// Add rank of highest numeral
if (caravan_size > 0) {
Slot slot = caravan->get_slot(caravan->get_size());
(*gs)[(*i_gs)++] = rank_to_uint8_t(slot.card.rank);

// Add face cards
for (uint8_t i_face = 0; i_face < TRACK_FACE_MAX; i_face++) {
if (i_face < slot.i_faces) {
(*gs)[(*i_gs)++] = rank_to_uint8_t(slot.faces[i_face].rank);
} else {
(*gs)[(*i_gs)++] = 0;
}
}
} else {
// No populated slot at caravan position, leave blank spaces
// for numeral and max face cards
for (uint8_t _ = 0; _ < (1 + TRACK_FACE_MAX); _++) {
(*gs)[(*i_gs)++] = 0;
}
}
} else {
(*gs)[(*i_gs)++] = 0;
}
}

Expand Down Expand Up @@ -126,7 +123,7 @@ void get_game_state(GameState *gs, Game *game, PlayerName pname) {
// Add state of each caravan, player's caravans first
for (uint8_t i_cvn = 0; i_cvn < cvn_names_size; i_cvn++) {
Caravan *caravan = table->get_caravan(cvn_names[i_cvn]);
add_caravan_to_game_state(gs, &i_gs, caravan);
add_caravan_to_game_state(gs, &i_gs, game, caravan);
}
}

Expand Down Expand Up @@ -349,8 +346,6 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train

// Otherwise, pick the optimal action from the q-table
action = action_pool[action_index];

printf("- %llu\n", q_table[gs].size());
}

// Generate input from action
Expand Down Expand Up @@ -398,12 +393,6 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train
}

q_table[last_gs][last_action] = q_table[last_gs][last_action] + tc.learning * (tc.discount * q_table[gs][action] - q_table[last_gs][last_action]);
/*
if (game->get_winner_name() != NO_PLAYER) {
//printf("%f\n", q_table[gs][action]);
printf("%f\n", q_table[last_gs][last_action]);
}
*/
}

// Log last move
Expand Down
13 changes: 7 additions & 6 deletions src/caravan/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int main(int argc, char *argv[]) {
// Training parameters TODO user-defined arguments
float discount = 0.95;
float learning = 0.7;
uint32_t episode_max = 1000000;
uint32_t episode_max = 50000;

// Game config uses largest deck with most samples and balance to
// maximise chance of encountering every player hand combination.
Expand All @@ -54,11 +54,6 @@ int main(int argc, char *argv[]) {
};

for(; tc.episode <= tc.episode_max; tc.episode++) {
if (tc.episode % 100 == 0) {
printf("Episode %d\n", tc.episode);
printf("- states: %llu\n", q_table.size());
}

// Random first player
rand_first = dist_first_player(gen);
gc.player_first = rand_first == NUM_PLAYER_ABC ?
Expand All @@ -74,6 +69,12 @@ int main(int argc, char *argv[]) {

tc.learning = learning;

if (tc.episode % 1000 == 0) {
printf("Episode %d\n", tc.episode);
printf("- explore: %.2f\n", tc.explore);
printf("- states: %llu\n", q_table.size());
}

// Start a new game
game.reset(new Game(&gc));

Expand Down
4 changes: 2 additions & 2 deletions test/caravan/model/test_game.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ TEST (TestGame, GetPlayerTurn) {
};
Game g{&gc};

ASSERT_EQ(g.get_player_turn(), PLAYER_ABC);
ASSERT_EQ(g.get_player_turn()->get_name(), PLAYER_ABC);
}

TEST (TestGame, GetWinner_NoMoves) {
Expand All @@ -59,7 +59,7 @@ TEST (TestGame, GetWinner_NoMoves) {
};
Game g{&gc};

ASSERT_EQ(g.get_winner(), NO_PLAYER);
ASSERT_EQ(g.get_winner_name(), NO_PLAYER);
}

TEST (TestGame, PlayOption_Error_StartRound_Remove) {
Expand Down

0 comments on commit d9239d3

Please sign in to comment.