First training session completed and the performance is quite good; training can now move on to a complex environment, with testing against a dynamic obstacle. (model name: sac_turtlebot3_final.zip)
1 parent 7e7fa03 · commit be8983a
Showing 5 changed files with 161 additions and 74 deletions.
Binary file not shown.
Binary file not shown.
@@ -1,22 +1,39 @@
-from stable_baselines3 import PPO
+import gym
+from stable_baselines3 import SAC
 from turtlebot3_rl.env import TurtleBot3Env
 import numpy as np
 
 def main():
     # Initialize the environment
     env = TurtleBot3Env()
 
     # Load the trained model
-    model = PPO.load("ppo_turtlebot3")
+    model = SAC.load("sac_turtlebot3_final.zip")
 
-    obs, _ = env.reset()
-    for _ in range(20):  # Test for a small number of steps
-        action, _states = model.predict(obs, deterministic=True)
-        obs, reward, done, truncated, info = env.step(action)
-        print(f"Action: {action}, Min Distance: {np.min(obs)}, Reward: {reward}, Done: {done}")
-        if done or truncated:
-            obs, _ = env.reset()
-
-    env.close()
+    try:
+        while True:  # Run episodes until the user interrupts
+            obs, _ = env.reset()  # Unpack the observation from the reset tuple
+            done = False
+            total_reward = 0.0
+
+            while not done:
+                # The model predicts an action based on the observation
+                action, _states = model.predict(obs, deterministic=True)
+
+                # Apply the action in the environment (Gymnasium-style 5-tuple)
+                obs, reward, terminated, truncated, info = env.step(action)
+                done = terminated or truncated  # End the episode on either signal
+
+                # Accumulate the reward for this episode
+                total_reward += reward
+
+            print(f"Episode finished: Total Reward: {total_reward}")
+
+    except KeyboardInterrupt:
+        print("Testing interrupted by user. Exiting...")
+
+    finally:
+        # Close the environment properly
+        env.close()
 
 if __name__ == '__main__':
     main()
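For a quick sanity check of the saved policy, Stable Baselines3 also ships an evaluation helper. The sketch below is an illustrative alternative to the hand-rolled loop in this commit (not part of the commit itself), assuming the same TurtleBot3Env and model file:

from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from turtlebot3_rl.env import TurtleBot3Env

env = Monitor(TurtleBot3Env())  # Monitor records episode rewards and lengths
model = SAC.load("sac_turtlebot3_final.zip", env=env)

# Run full episodes and report the mean/std of the episode reward
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"Mean reward over 10 episodes: {mean_reward:.2f} +/- {std_reward:.2f}")
env.close()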
@@ -1,26 +1,79 @@
-import gym
-from stable_baselines3 import PPO
-from stable_baselines3.common.monitor import Monitor
-from stable_baselines3.common.env_checker import check_env
-from turtlebot3_rl.env import TurtleBot3Env
-import numpy as np
-np.bool = np.bool_
-
-def main():
-    # Initialize the environment
-    env = TurtleBot3Env()
-    check_env(env)
-
-    # Define the RL algorithm (PPO in this case)
-    model = PPO("MlpPolicy", env, ent_coef=0.01, learning_rate=0.000001, n_steps=2048, batch_size=64, n_epochs=10, clip_range=0.1, verbose=1)
-
-    # Train the model
-    model.learn(total_timesteps=1000000)  # 50000
-
-    # Save the model
-    model.save("ppo_turtlebot3")
-
-    # Close the environment
-    env.close()
-
-if __name__ == '__main__':
-    main()
+'''import gym
+from stable_baselines3 import SAC
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
+from turtlebot3_rl.env import TurtleBot3Env
+
+def main():
+    # Initialize the environment
+    env = TurtleBot3Env()
+    check_env(env)  # Ensure that the environment adheres to the Stable Baselines3 API
+
+    # Wrap the environment in a Monitor for logging
+    env = Monitor(env)
+
+    # Define the total number of timesteps for training
+    timesteps = 1000000  # Define this before using it in the lambda function
+
+    # Define the learning rate schedule
+    learning_rate = lambda x: 0.0003 * (1 - x / timesteps)
+
+    # Define the RL algorithm (SAC) with potential adjustments
+    model = SAC("MlpPolicy", env, verbose=1, learning_rate=learning_rate, buffer_size=200000, learning_starts=1000, batch_size=128, tau=0.005, gamma=0.99, train_freq=4, gradient_steps=4, use_sde=True)
+
+    # Callbacks for saving models and evaluation
+    checkpoint_interval = 50000  # Define the interval for saving checkpoints
+    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path='./models/', name_prefix='sac_turtlebot3')
+    eval_callback = EvalCallback(env, best_model_save_path='./models/', log_path='./logs/', eval_freq=checkpoint_interval, deterministic=True)
+
+    # Train the model with checkpoints
+    model.learn(total_timesteps=timesteps, reset_num_timesteps=False, callback=[checkpoint_callback, eval_callback])
+
+    # Save the final model
+    model.save("sac_turtlebot3_final")
+
+    # Close the environment
+    env.close()
+
+if __name__ == '__main__':
+    main()
+'''
+import gym
+from stable_baselines3 import SAC
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
+from turtlebot3_rl.env import TurtleBot3Env
+
+def main():
+    # Initialize the environment
+    env = TurtleBot3Env()
+    check_env(env)  # Ensure that the environment adheres to the Stable Baselines3 API
+
+    # Wrap the environment in a Monitor for logging
+    env = Monitor(env)
+
+    # Define the total number of additional timesteps for training
+    additional_timesteps = 500000  # Define how many more timesteps you want to train
+
+    # Load the existing model
+    model = SAC.load("sac_turtlebot3_final.zip", env=env)
+
+    # Callbacks for saving models and evaluation during continued training
+    checkpoint_interval = 50000  # Define the interval for saving checkpoints
+    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path='./models/', name_prefix='sac_turtlebot3')
+    eval_callback = EvalCallback(env, best_model_save_path='./models/', log_path='./logs/', eval_freq=checkpoint_interval, deterministic=True)
+
+    # Continue training the model with additional timesteps
+    model.learn(total_timesteps=additional_timesteps, reset_num_timesteps=False, callback=[checkpoint_callback, eval_callback])
+
+    # Save the updated model
+    model.save("sac_turtlebot3_final_finetuned")
+
+    # Close the environment
+    env.close()
+
+if __name__ == '__main__':
+    main()
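One caveat on the learning-rate schedule in the commented-out script above: Stable Baselines3 calls a schedule with progress_remaining, a value that decreases from 1 at the start of training to 0 at the end, not with a timestep count. Since x stays in [0, 1], 0.0003 * (1 - x / timesteps) remains effectively constant at 0.0003, so the intended decay never happens. A minimal sketch of a linear decay in the form SB3 expects:

def linear_schedule(initial_lr: float):
    # SB3 passes progress_remaining, which goes from 1 (start) down to 0 (end)
    def schedule(progress_remaining: float) -> float:
        return initial_lr * progress_remaining
    return schedule

# Hypothetical usage with the value from the script above:
# model = SAC("MlpPolicy", env, learning_rate=linear_schedule(0.0003), ...)

Note that SAC.load restores the stored schedule along with the other hyperparameters; when resuming training, the custom_objects argument of load (e.g. custom_objects={"learning_rate": linear_schedule(0.0003)}) can override it.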