-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
In this update, i complete RL training, and testing, but the performance is still not good enough.
- Loading branch information
1 parent
e430708
commit 7e7fa03
Showing
12 changed files
with
302 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
<?xml version="1.0"?> | ||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?> | ||
<package format="3"> | ||
<name>turtlebot3_rl</name> | ||
<version>0.0.0</version> | ||
<description>TODO: Package description</description> | ||
<maintainer email="stanleychueh28@gmail.com">stanley</maintainer> | ||
<license>TODO: License declaration</license> | ||
|
||
<test_depend>ament_copyright</test_depend> | ||
<test_depend>ament_flake8</test_depend> | ||
<test_depend>ament_pep257</test_depend> | ||
<test_depend>python3-pytest</test_depend> | ||
|
||
<export> | ||
<build_type>ament_python</build_type> | ||
</export> | ||
</package> |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[develop] | ||
script-dir=$base/lib/turtlebot3_rl | ||
[install] | ||
install-scripts=$base/lib/turtlebot3_rl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from setuptools import setup | ||
|
||
package_name = 'turtlebot3_rl' | ||
|
||
setup( | ||
name=package_name, | ||
version='0.0.0', | ||
packages=[package_name], | ||
data_files=[ | ||
('share/ament_index/resource_index/packages', | ||
['resource/' + package_name]), | ||
('share/' + package_name, ['package.xml']), | ||
], | ||
install_requires=['setuptools'], | ||
zip_safe=True, | ||
maintainer='stanley', | ||
maintainer_email='stanley@example.com', | ||
description='RL package for TurtleBot3 in ROS2', | ||
license='Apache License 2.0', | ||
tests_require=['pytest'], | ||
entry_points={ | ||
'console_scripts': [ | ||
'train = turtlebot3_rl.train:main', | ||
'test = turtlebot3_rl.test:main', | ||
], | ||
}, | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2015 Open Source Robotics Foundation, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from ament_copyright.main import main | ||
import pytest | ||
|
||
|
||
@pytest.mark.copyright | ||
@pytest.mark.linter | ||
def test_copyright(): | ||
rc = main(argv=['.', 'test']) | ||
assert rc == 0, 'Found errors' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright 2017 Open Source Robotics Foundation, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from ament_flake8.main import main_with_errors | ||
import pytest | ||
|
||
|
||
@pytest.mark.flake8 | ||
@pytest.mark.linter | ||
def test_flake8(): | ||
rc, errors = main_with_errors(argv=[]) | ||
assert rc == 0, \ | ||
'Found %d code style errors / warnings:\n' % len(errors) + \ | ||
'\n'.join(errors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2015 Open Source Robotics Foundation, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from ament_pep257.main import main | ||
import pytest | ||
|
||
|
||
@pytest.mark.linter | ||
@pytest.mark.pep257 | ||
def test_pep257(): | ||
rc = main(argv=['.', 'test']) | ||
assert rc == 0, 'Found code style errors / warnings' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import gymnasium as gym | ||
from gymnasium import spaces | ||
import numpy as np | ||
import rclpy | ||
from rclpy.node import Node | ||
from rclpy.executors import MultiThreadedExecutor | ||
from geometry_msgs.msg import Twist | ||
from sensor_msgs.msg import LaserScan | ||
import threading | ||
|
||
class TurtleBot3Env(gym.Env): | ||
def __init__(self): | ||
super(TurtleBot3Env, self).__init__() | ||
|
||
# Initialize ROS2 node | ||
rclpy.init(args=None) | ||
self.node = rclpy.create_node('turtlebot3_rl_env') | ||
|
||
# Action space: linear and angular velocity | ||
# Reduced linear velocity for more cautious movement | ||
self.action_space = spaces.Box(low=np.array([-0.1, -0.05]), high=np.array([0.1, 0.05]), dtype=np.float32) | ||
|
||
# Observation space: laser scan data | ||
self.observation_space = spaces.Box(low=0.0, high=3.5, shape=(360,), dtype=np.float32) | ||
|
||
# Publishers and subscribers | ||
self.cmd_vel_pub = self.node.create_publisher(Twist, 'cmd_vel', 10) | ||
self.scan_sub = self.node.create_subscription( | ||
LaserScan, | ||
'/scan', | ||
self.scan_callback, | ||
10 # Queue size | ||
) | ||
|
||
# Initialize scan data | ||
self.scan_data = np.zeros(360, dtype=np.float32) | ||
|
||
# Use the executor to manage the node spinning | ||
self.executor = MultiThreadedExecutor() | ||
self.executor.add_node(self.node) | ||
|
||
# Start a separate thread to spin the ROS2 node | ||
self.spin_thread = threading.Thread(target=self.executor.spin) | ||
self.spin_thread.start() | ||
|
||
def scan_callback(self, msg): | ||
# Handle laser scan data | ||
raw_scan_data = np.array(msg.ranges, dtype=np.float32) | ||
|
||
# Replace 0.0 values with a large number, indicating 'no detection' | ||
raw_scan_data[raw_scan_data == 0.0] = np.inf | ||
|
||
# Ensure min range reflects realistic sensor limits | ||
self.scan_data = np.clip(raw_scan_data, 0.12, 3.5) | ||
|
||
def reset(self, seed=None, options=None): | ||
# Reset the environment | ||
self.node.get_logger().info("Resetting the environment...") | ||
|
||
if seed is not None: | ||
self.seed_val = seed | ||
self._np_random, seed = gym.utils.seeding.np_random(seed) | ||
|
||
return self.scan_data.astype(np.float32), {} | ||
|
||
def step(self, action): | ||
twist = Twist() | ||
twist.linear.x = float(action[0]) | ||
twist.angular.z = float(action[1]) | ||
self.cmd_vel_pub.publish(twist) | ||
|
||
# Log scan data for debugging | ||
self.node.get_logger().info(f"Current scan data: {self.scan_data[:10]}") # Log the first 10 laser scan values | ||
|
||
observation = self.scan_data.astype(np.float32) | ||
reward = self.calculate_reward() | ||
done = self.check_done() | ||
|
||
# Log the action and the corresponding observation | ||
self.node.get_logger().info(f"Action: {action}, Min Distance: {np.min(self.scan_data)}, Reward: {reward}") | ||
|
||
return observation, reward, done, False, {} | ||
|
||
def calculate_reward(self): | ||
min_distance = np.min(self.scan_data) | ||
|
||
# Collision penalty | ||
if min_distance < 0.2: | ||
return -300.0 # High penalty for collision | ||
|
||
# Reward for moving forward | ||
forward_movement_reward = 1.0 # Maintain a modest reward for moving forward | ||
|
||
# Penalize proximity to obstacles more heavily | ||
proximity_penalty = -50.0 * (0.5 - min_distance) if min_distance < 0.5 else 0.0 | ||
|
||
# Reward for maintaining a safe distance | ||
safe_distance_reward = 15.0 if min_distance > 0.5 else 0.0 | ||
|
||
# Small reward for each step the robot avoids collision (time-based reward) | ||
time_step_reward = 0.2 # Encourage staying in a safe state over time | ||
|
||
# Combine rewards and penalties | ||
reward = forward_movement_reward + proximity_penalty + safe_distance_reward + time_step_reward | ||
|
||
return reward | ||
|
||
def check_done(self): | ||
min_distance = np.min(self.scan_data) | ||
if min_distance < 0.2: | ||
return True | ||
return False | ||
|
||
def close(self): | ||
self.executor.shutdown() | ||
self.spin_thread.join() | ||
rclpy.shutdown() | ||
|
||
def main(): | ||
rclpy.init(args=None) | ||
env = TurtleBot3Env() | ||
|
||
# Example test loop | ||
for _ in range(100): # Run for 100 steps as an example | ||
action = env.action_space.sample() # Random action | ||
observation, reward, done, _, _ = env.step(action) | ||
if done: | ||
env.reset() | ||
|
||
env.close() | ||
|
||
if __name__ == '__main__': | ||
main() |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from stable_baselines3 import PPO | ||
from turtlebot3_rl.env import TurtleBot3Env | ||
import numpy as np | ||
|
||
def main(): | ||
env = TurtleBot3Env() | ||
|
||
# Load the trained model | ||
model = PPO.load("ppo_turtlebot3") | ||
|
||
obs, _ = env.reset() | ||
for _ in range(20): # Test for a small number of steps | ||
action, _states = model.predict(obs, deterministic=True) | ||
obs, reward, done, truncated, info = env.step(action) | ||
print(f"Action: {action}, Min Distance: {np.min(obs)}, Reward: {reward}, Done: {done}") | ||
if done or truncated: | ||
obs, _ = env.reset() | ||
|
||
env.close() | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import gym | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_checker import check_env | ||
from turtlebot3_rl.env import TurtleBot3Env | ||
import numpy as np | ||
np.bool = np.bool_ | ||
|
||
|
||
def main(): | ||
env = TurtleBot3Env() | ||
check_env(env) | ||
|
||
# Define the RL algorithm (PPO in this case) | ||
model = PPO("MlpPolicy", env, ent_coef=0.01, learning_rate=0.000001, n_steps=2048, batch_size=64, n_epochs=10, clip_range=0.1, verbose=1) | ||
|
||
# Train the model | ||
model.learn(total_timesteps=1000000)#50000 | ||
|
||
# Save the model | ||
model.save("ppo_turtlebot3") | ||
|
||
# Close the environment | ||
env.close() | ||
|
||
if __name__ == '__main__': | ||
main() |