
Commit

In this update, I completed RL training and testing, but the performance is still not good enough.
StanleyChueh committed Aug 15, 2024
1 parent e430708 commit 7e7fa03
Showing 12 changed files with 302 additions and 0 deletions.
18 changes: 18 additions & 0 deletions turtlebot3_rl/package.xml
@@ -0,0 +1,18 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>turtlebot3_rl</name>
  <version>0.0.0</version>
  <description>TODO: Package description</description>
  <maintainer email="stanleychueh28@gmail.com">stanley</maintainer>
  <license>TODO: License declaration</license>

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>
Empty file.
4 changes: 4 additions & 0 deletions turtlebot3_rl/setup.cfg
@@ -0,0 +1,4 @@
[develop]
script-dir=$base/lib/turtlebot3_rl
[install]
install-scripts=$base/lib/turtlebot3_rl
28 changes: 28 additions & 0 deletions turtlebot3_rl/setup.py
@@ -0,0 +1,28 @@
from setuptools import setup

package_name = 'turtlebot3_rl'

setup(
    name=package_name,
    version='0.0.0',
    packages=[package_name],
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='stanley',
    maintainer_email='stanley@example.com',
    description='RL package for TurtleBot3 in ROS2',
    license='Apache License 2.0',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'train = turtlebot3_rl.train:main',
            'test = turtlebot3_rl.test:main',
        ],
    },
)
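
With these entry_points, the training and testing scripts become executables after a standard ament_python build. Assuming a typical ROS 2 workspace layout (the workspace path is an assumption, not part of this commit), they would be invoked with colcon build --packages-select turtlebot3_rl, then ros2 run turtlebot3_rl train or ros2 run turtlebot3_rl test.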

23 changes: 23 additions & 0 deletions turtlebot3_rl/test/test_copyright.py
@@ -0,0 +1,23 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ament_copyright.main import main
import pytest


@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    rc = main(argv=['.', 'test'])
    assert rc == 0, 'Found errors'
25 changes: 25 additions & 0 deletions turtlebot3_rl/test/test_flake8.py
@@ -0,0 +1,25 @@
# Copyright 2017 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ament_flake8.main import main_with_errors
import pytest


@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    rc, errors = main_with_errors(argv=[])
    assert rc == 0, \
        'Found %d code style errors / warnings:\n' % len(errors) + \
        '\n'.join(errors)
23 changes: 23 additions & 0 deletions turtlebot3_rl/test/test_pep257.py
@@ -0,0 +1,23 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ament_pep257.main import main
import pytest


@pytest.mark.linter
@pytest.mark.pep257
def test_pep257():
    rc = main(argv=['.', 'test'])
    assert rc == 0, 'Found code style errors / warnings'
Empty file.
133 changes: 133 additions & 0 deletions turtlebot3_rl/turtlebot3_rl/env.py
@@ -0,0 +1,133 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import rclpy
from rclpy.executors import MultiThreadedExecutor
from geometry_msgs.msg import Twist
from sensor_msgs.msg import LaserScan
import threading

class TurtleBot3Env(gym.Env):
    def __init__(self):
        super(TurtleBot3Env, self).__init__()

        # Initialize ROS2 only if it is not already initialized
        # (avoids a double-init error when the caller runs rclpy.init() itself)
        if not rclpy.ok():
            rclpy.init(args=None)
        self.node = rclpy.create_node('turtlebot3_rl_env')

        # Action space: linear and angular velocity
        # Reduced linear velocity for more cautious movement
        self.action_space = spaces.Box(low=np.array([-0.1, -0.05]), high=np.array([0.1, 0.05]), dtype=np.float32)

        # Observation space: 360 laser scan ranges in meters
        self.observation_space = spaces.Box(low=0.0, high=3.5, shape=(360,), dtype=np.float32)

        # Publishers and subscribers
        self.cmd_vel_pub = self.node.create_publisher(Twist, 'cmd_vel', 10)
        self.scan_sub = self.node.create_subscription(
            LaserScan,
            '/scan',
            self.scan_callback,
            10  # Queue size
        )

        # Initialize scan data
        self.scan_data = np.zeros(360, dtype=np.float32)

        # Use the executor to manage the node spinning
        self.executor = MultiThreadedExecutor()
        self.executor.add_node(self.node)

        # Start a separate thread to spin the ROS2 node
        self.spin_thread = threading.Thread(target=self.executor.spin)
        self.spin_thread.start()

    def scan_callback(self, msg):
        # Handle laser scan data
        raw_scan_data = np.array(msg.ranges, dtype=np.float32)

        # Replace 0.0 values ('no detection') with inf so they clip to max range
        raw_scan_data[raw_scan_data == 0.0] = np.inf

        # Clip to the sensor's realistic limits (0.12 m to 3.5 m)
        self.scan_data = np.clip(raw_scan_data, 0.12, 3.5)

    def reset(self, seed=None, options=None):
        # Reset the environment and return the latest scan as the observation
        self.node.get_logger().info("Resetting the environment...")

        if seed is not None:
            self.seed_val = seed
            self._np_random, seed = gym.utils.seeding.np_random(seed)

        return self.scan_data.astype(np.float32), {}

    def step(self, action):
        twist = Twist()
        twist.linear.x = float(action[0])
        twist.angular.z = float(action[1])
        self.cmd_vel_pub.publish(twist)

        # Log the first 10 laser scan values for debugging
        self.node.get_logger().info(f"Current scan data: {self.scan_data[:10]}")

        observation = self.scan_data.astype(np.float32)
        reward = self.calculate_reward()
        done = self.check_done()

        # Log the action and the corresponding observation
        self.node.get_logger().info(f"Action: {action}, Min Distance: {np.min(self.scan_data)}, Reward: {reward}")

        return observation, reward, done, False, {}

    def calculate_reward(self):
        min_distance = np.min(self.scan_data)

        # Collision penalty
        if min_distance < 0.2:
            return -300.0  # High penalty for collision

        # Reward for moving forward
        forward_movement_reward = 1.0  # Maintain a modest reward for moving forward

        # Penalize proximity to obstacles more heavily
        proximity_penalty = -50.0 * (0.5 - min_distance) if min_distance < 0.5 else 0.0

        # Reward for maintaining a safe distance
        safe_distance_reward = 15.0 if min_distance > 0.5 else 0.0

        # Small reward for each step the robot avoids collision (time-based reward)
        time_step_reward = 0.2  # Encourage staying in a safe state over time

        # Combine rewards and penalties
        reward = forward_movement_reward + proximity_penalty + safe_distance_reward + time_step_reward

        return reward

    def check_done(self):
        # Episode ends on collision (minimum scan range below 0.2 m)
        return np.min(self.scan_data) < 0.2

    def close(self):
        self.executor.shutdown()
        self.spin_thread.join()
        rclpy.shutdown()

def main():
    env = TurtleBot3Env()  # rclpy is initialized inside the environment

    # Example test loop
    for _ in range(100):  # Run for 100 steps as an example
        action = env.action_space.sample()  # Random action
        observation, reward, done, _, _ = env.step(action)
        if done:
            env.reset()

    env.close()

if __name__ == '__main__':
    main()
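
As a side note (not part of this commit), the reward shaping above can be sanity-checked in isolation. A minimal standalone sketch, mirroring the constants in calculate_reward, with hypothetical sample distances:

def reward_for(min_distance):
    # Mirrors TurtleBot3Env.calculate_reward above
    if min_distance < 0.2:
        return -300.0
    proximity_penalty = -50.0 * (0.5 - min_distance) if min_distance < 0.5 else 0.0
    safe_distance_reward = 15.0 if min_distance > 0.5 else 0.0
    return 1.0 + proximity_penalty + safe_distance_reward + 0.2

for d in (0.15, 0.3, 0.45, 0.6, 1.0):
    print(f"min_distance={d:.2f} m -> reward={reward_for(d):.1f}")
# Prints: -300.0, -8.8, -1.3, 16.2, 16.2

The reward is discontinuous at 0.2 m (collision) and jumps from about -1.3 just inside 0.5 m to 16.2 just beyond it, which may be relevant to the training instability the commit message mentions.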
Binary file added turtlebot3_rl/turtlebot3_rl/ppo_turtlebot3.zip
Binary file not shown.
22 changes: 22 additions & 0 deletions turtlebot3_rl/turtlebot3_rl/test.py
@@ -0,0 +1,22 @@
from stable_baselines3 import PPO
from turtlebot3_rl.env import TurtleBot3Env
import numpy as np

def main():
    env = TurtleBot3Env()

    # Load the trained model
    model = PPO.load("ppo_turtlebot3")

    obs, _ = env.reset()
    for _ in range(20):  # Test for a small number of steps
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        print(f"Action: {action}, Min Distance: {np.min(obs)}, Reward: {reward}, Done: {done}")
        if done or truncated:
            obs, _ = env.reset()

    env.close()

if __name__ == '__main__':
    main()
26 changes: 26 additions & 0 deletions turtlebot3_rl/turtlebot3_rl/train.py
@@ -0,0 +1,26 @@
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from turtlebot3_rl.env import TurtleBot3Env
import numpy as np
np.bool = np.bool_  # Compatibility shim: restore the alias removed in NumPy 1.24+


def main():
    env = TurtleBot3Env()
    check_env(env)

    # Define the RL algorithm (PPO in this case)
    model = PPO("MlpPolicy", env, ent_coef=0.01, learning_rate=0.000001, n_steps=2048, batch_size=64, n_epochs=10, clip_range=0.1, verbose=1)

    # Train the model (previously 50000 timesteps)
    model.learn(total_timesteps=1000000)

    # Save the model
    model.save("ppo_turtlebot3")

    # Close the environment
    env.close()

if __name__ == '__main__':
    main()
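
Since the commit message notes the performance is still not good enough, keeping intermediate snapshots makes it easier to compare training stages. A minimal sketch, assuming stable_baselines3's CheckpointCallback and replacing the model.learn(...) call above; the save frequency and path are hypothetical choices, not part of this commit:

from stable_baselines3.common.callbacks import CheckpointCallback

# Save a snapshot of the policy every 50,000 timesteps (frequency is an assumption)
checkpoint_callback = CheckpointCallback(
    save_freq=50_000,
    save_path="./checkpoints/",
    name_prefix="ppo_turtlebot3"
)
model.learn(total_timesteps=1000000, callback=checkpoint_callback)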
