Commit
Added new reward method, and tried different training vars
Jaden505 committed Jul 24, 2023
1 parent 2788fe3 commit 91db820
Showing 6 changed files with 25 additions and 13 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -132,4 +132,8 @@ dmypy.json
 
 # PyCharm
 .idea/
-__pycache__/
+__pycache__/
+
+models
+./models/
+models/
18 changes: 13 additions & 5 deletions cube/helper_cube.py
@@ -86,14 +86,22 @@ def step(self, action):
         #
         # return reward + solved_face
 
-    def reward_action(self, state, next_state):
-        reward = self.reward_color_count(next_state) - self.reward_color_count(state)
+    # def reward_action(self, state, next_state):
+    #     reward = self.reward_color_count(next_state) - self.reward_color_count(state)
+    #
+    #     if (self.reward_face_solved(state, next_state)) > 0:
+    #         reward = max(reward + 0.4, 1)
+    #
+    #     if self.check_solved():
+    #         reward = 1
+    #
+    #     return reward
 
-        if (self.reward_face_solved(state, next_state) / 6) > 0:
-            reward = max(reward + 0.4, 1)
+    def reward_action(self, state, next_state):
+        reward = CubeHelper.reward_face_solved(state, next_state) / 6
 
         if self.check_solved():
-            reward = 1
+            reward = 10
 
         return reward
 
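For reference, a standalone sketch of the reward scheme this hunk switches to: the per-move reward is the solved-face signal scaled by 1/6, overridden by a flat 10 once the whole cube is solved. The 6x3x3 state layout, count_solved_faces, and treating reward_face_solved as a change in solved-face count are assumptions made for illustration; only the /6 scaling and the 10 bonus are taken from the diff.

import numpy as np

def count_solved_faces(state):
    # A face counts as solved when all nine stickers share one color.
    return sum(np.all(face == face[0, 0]) for face in state)

def reward_face_solved(state, next_state):
    # Assumed semantics: how many more faces are solved after the move.
    return count_solved_faces(next_state) - count_solved_faces(state)

def reward_action(state, next_state):
    # Mirrors the new reward_action: small shaped signal, big terminal bonus.
    reward = reward_face_solved(state, next_state) / 6

    if count_solved_faces(next_state) == 6:  # stands in for self.check_solved()
        reward = 10

    return reward

# Example: a move that breaks one previously solved face yields -1/6 under these assumptions.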
2 changes: 0 additions & 2 deletions solver/dqn_agent.py → dqn/dqn_agent.py
@@ -34,8 +34,6 @@ def __init__(self):
         self.rotation_dict = {0: "U", 1: "U'", 2: "D", 3: "D'", 4: "L", 5: "L'",
                               6: "R", 7: "R'", 8: "F", 9: "F'", 10: "B", 11: "B'"}
 
-        # self.prev_pred = None
-
     def create_model(self):
         """
         Creates the model for the neural network
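A small sketch of how this 12-entry action space is read: each discrete action index the network outputs corresponds to one quarter-turn of a face (U, D, L, R, F, B, plus primes for the counter-clockwise turns). The greedy_move helper below is an assumed usage pattern, not code from the repository.

import numpy as np

ROTATION_DICT = {0: "U", 1: "U'", 2: "D", 3: "D'", 4: "L", 5: "L'",
                 6: "R", 7: "R'", 8: "F", 9: "F'", 10: "B", 11: "B'"}

def greedy_move(q_values):
    # Pick the face turn with the highest predicted Q-value.
    return ROTATION_DICT[int(np.argmax(q_values))]

# e.g. greedy_move(np.random.rand(12)) returns one of the 12 move strings.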
9 changes: 6 additions & 3 deletions solver/main.py → dqn/main.py
@@ -1,6 +1,6 @@
 from cube.helper_cube import CubeHelper
-from solver.dqn_agent import DqnAgent
-from solver.replay_buffer import ReplayBuffer
+from dqn.dqn_agent import DqnAgent
+from dqn.replay_buffer import ReplayBuffer
 
 import copy
 
@@ -16,6 +16,8 @@ def __init__(self):
         self.TARGET_UPDATE = 5
         self.UPDATE_ALL_TD = 2
 
+        self.model_save_path = "../models/model.h5"
+
     def train_model(self):
         for step in range(self.STEPS):
             self.get_train_data()
@@ -29,8 +31,9 @@ def train_model(self):
 
             if step % self.TARGET_UPDATE == 0:
                 self.agent.update_target_model()
+                self.agent.model.save(self.model_save_path)
 
-        self.agent.model.save("../models/model1.h5")
+        self.agent.model.save(self.model_save_path)
 
     def get_train_data(self):
         self.cube.scramble()
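The net effect of these two hunks is a simple checkpointing schedule: the model is written to model_save_path every TARGET_UPDATE steps, right after the target-network sync, and once more after the final step. A minimal illustration with stub objects (the real DqnAgent and its Keras model live elsewhere in the repo, and the STEPS value here is assumed):

class StubModel:
    def save(self, path):
        print(f"saved model to {path}")

class StubAgent:
    def __init__(self):
        self.model = StubModel()
    def update_target_model(self):
        print("target network synced")

agent = StubAgent()
MODEL_SAVE_PATH = "../models/model.h5"
TARGET_UPDATE = 5
STEPS = 20  # assumed value; the real STEPS is defined outside this hunk

for step in range(STEPS):
    # ... gather scrambles, fill the replay buffer, fit the online network ...
    if step % TARGET_UPDATE == 0:
        agent.update_target_model()
        agent.model.save(MODEL_SAVE_PATH)  # periodic checkpoint

agent.model.save(MODEL_SAVE_PATH)  # final save so the last weights are kept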
File renamed without changes.
3 changes: 1 addition & 2 deletions solver/test_predict.py → dqn/test_predict.py
@@ -6,10 +6,9 @@
 
 cube = CubeHelper()
 agent = DqnAgent()
-model = load_model('../models/model1.h5')
+model = load_model('../models/model.h5')
 
 agent.model = model
-[]
 def try_solve():
     cube.scramble()
     state = cube.get_cube_state()
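The body of try_solve is truncated in this diff, so the rollout below is a hypothetical sketch of how the loaded model might be used, not the repository's actual function: cube.rotate and the max_moves cap are assumptions, while scramble, get_cube_state and check_solved appear elsewhere in the commit.

import numpy as np

ROTATION_DICT = {0: "U", 1: "U'", 2: "D", 3: "D'", 4: "L", 5: "L'",
                 6: "R", 7: "R'", 8: "F", 9: "F'", 10: "B", 11: "B'"}

def try_solve(cube, model, max_moves=50):
    # Greedy rollout: repeatedly apply the move with the highest predicted Q-value.
    cube.scramble()
    for _ in range(max_moves):
        state = cube.get_cube_state()
        q_values = model.predict(np.expand_dims(state, 0), verbose=0)[0]
        cube.rotate(ROTATION_DICT[int(np.argmax(q_values))])  # rotate() is an assumed helper
        if cube.check_solved():
            return True
    return False

# Usage (after the load_model call in the diff above): try_solve(cube, model)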
