Skip to content

Commit

Permalink
Merge pull request #1 from DHBW-Smart-Rollerz/fix/env-transform
Browse files Browse the repository at this point in the history
Changed Exe and original img size
  • Loading branch information
jonasweihing authored Oct 10, 2024
2 parents b3cf73c + d976a63 commit a0ea9a6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ This repository contains the ros2 jazzy package for the ai lane detection based
pip install -r requirements.txt
```
3. Download the trained model and its corresponding configuration file. (See the [Usage](#usage) section for more details.)
4. Please set the following ENV variable in your .bashrc/.zshrc if not already done:
```bash
PYTHON_EXECUTABLE="/home/$USER/.pyenv/versions/default/bin/python3" # Change this to the python3 executable path of your pyenv
```

## Usage

Expand Down
47 changes: 44 additions & 3 deletions lane_detection_ai/model/model_wrapper.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,27 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from camera_preprocessing.transformation.calibration import Calibration
from timing.timer import Timer

from lane_detection_ai.model.utils.common import get_config, get_model


class LaneDetectionAiModel:
"""LaneDetectionAiModel class"""
"""LaneDetectionAiModel class."""

def __init__(self, base_path: str, model_config_path: str):
"""
Initialize the LaneDetectionAiModel class.
Arguments:
base_path -- The base path to the model.
model_config_path -- The path to the model configuration file.
"""
torch.backends.cudnn.benchmark = True

self.camera_calibration = Calibration()
self.camera_calibration.setup()
self.config = get_config(os.path.join(base_path, model_config_path))
self.config.test_model = os.path.join(base_path, self.config.test_model)
self.image_transform = transforms.Compose(
Expand All @@ -28,6 +38,12 @@ def __init__(self, base_path: str, model_config_path: str):
self.net = self.load_model()

def load_model(self):
"""
Load the model.
Returns:
torch.nn.Module -- The model.
"""
self.config.batch_size = 1

assert self.config.backbone in [
Expand Down Expand Up @@ -60,6 +76,15 @@ def load_model(self):
return net

def predict(self, image: np.ndarray) -> List[np.ndarray]:
"""
Predict the lanes in the image.
Arguments:
image -- The image.
Returns:
List[np.ndarray] -- The lanes.
"""
with Timer(name="image_transform", filter_strength=40):
image = cv2.resize(
image,
Expand All @@ -82,8 +107,8 @@ def predict(self, image: np.ndarray) -> List[np.ndarray]:
pred,
self.config.row_anchor,
self.config.col_anchor,
original_image_width=2064,
original_image_height=1544,
original_image_width=self.camera_calibration.target_size[0],
original_image_height=self.camera_calibration.target_size[1],
)

if len(coords[0]) > 0 and len(coords[3]) > 0:
Expand Down Expand Up @@ -124,6 +149,22 @@ def pred2coords(
original_image_width=1640,
original_image_height=590,
):
"""
Convert the prediction to coordinates.
Arguments:
pred -- Prediction.
row_anchor -- Row anchor.
col_anchor -- Column anchor.
Keyword Arguments:
local_width -- Local Width (default: {1})
original_image_width -- Original Image width (default: {1640})
original_image_height -- Original Image height (default: {590})
Returns:
List[np.ndarray] -- The coordinates.
"""
batch_size, num_grid_row, num_cls_row, num_lane_row = pred["loc_row"].shape
batch_size, num_grid_col, num_cls_col, num_lane_col = pred["loc_col"].shape

Expand Down
3 changes: 2 additions & 1 deletion launch/lane_detection_ai.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def generate_launch_description():
Returns:
LaunchDescription -- The launch description.
"""
python_executable = os.getenv("PYTHON_EXECUTABLE", "/usr/bin/python3")
debug = LaunchConfiguration("debug")
params_file = LaunchConfiguration("params_file")

Expand All @@ -41,7 +42,7 @@ def generate_launch_description():
{"debug": debug},
params_file,
],
prefix="/home/smartrollerz/.pyenv/versions/3.12.5/bin/python3",
prefix=[python_executable],
),
]
)

0 comments on commit a0ea9a6

Please sign in to comment.