Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jan 13, 2024
1 parent b9064b8 commit c30b3e0
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 151 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/peripheral.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Peripheral test
name: formatter and checker for python

on:
push:
Expand All @@ -14,10 +14,10 @@ jobs:
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: pip install formatters and mypy
- name: pip install formatters and checkers
run: |
pip3 install mypy flake8 isort
- name: pip install torch and numpy for build
pip3 install mypy isort black pyproject-flake8
- name: pip install torch and numpy for module build
run: |
# Quote the version specifiers: unquoted, the shell parses ">=1.24" and ">=2.0"
# as output redirections (creating files "=1.24" / "=2.0") and pip never sees the constraints.
pip3 install "numpy>=1.24" "torch>=2.0"
Expand All @@ -27,7 +27,7 @@ jobs:
mypy --version
mypy .
- name: check by isrot and flake8
- name: reformat and check
run: |
python3 -m isort example/ test/ node_script/
python3 -m flake8 example/ test/ node_script/
python3 -m black .
python3 -m pflake8 node_scripts/ scripts/ test/
24 changes: 6 additions & 18 deletions node_scripts/cutie_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
from utils import overlay_davis


class CutieNode(
object
): # should not be ConnectionBasedNode cause Cutie tracker needs continuous input
class CutieNode(object): # should not be ConnectionBasedNode cause Cutie tracker needs continuous input
def __init__(self):
super(CutieNode, self).__init__()
self.cutie_config = CutieConfig.from_rosparam()
Expand All @@ -33,12 +31,8 @@ def __init__(self):
queue_size=1,
buff_size=2**24,
)
self.pub_vis_img = rospy.Publisher(
"~output/segmentation_image", Image, queue_size=1
)
self.pub_segmentation_img = rospy.Publisher(
"~output/segmentation", Image, queue_size=1
)
self.pub_vis_img = rospy.Publisher("~output/segmentation_image", Image, queue_size=1)
self.pub_segmentation_img = rospy.Publisher("~output/segmentation", Image, queue_size=1)

@torch.inference_mode()
def initialize(self):
Expand All @@ -51,9 +45,7 @@ def initialize(self):
# initialize the model with the mask
with torch.cuda.amp.autocast(enabled=True):
image_torch = (
torch.from_numpy(self.image.transpose(2, 0, 1))
.float()
.to(self.cutie_config.device, non_blocking=True)
torch.from_numpy(self.image.transpose(2, 0, 1)).float().to(self.cutie_config.device, non_blocking=True)
/ 255
)
# initialize with the mask
Expand Down Expand Up @@ -86,15 +78,11 @@ def callback(self, img_msg):
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
with torch.cuda.amp.autocast(enabled=True):
image_torch = (
torch.from_numpy(self.image.transpose(2, 0, 1))
.float()
.to(self.cutie_config.device, non_blocking=True)
torch.from_numpy(self.image.transpose(2, 0, 1)).float().to(self.cutie_config.device, non_blocking=True)
/ 255
)
prediction = self.predictor.step(image_torch)
self.mask = (
torch.max(prediction, dim=0).indices.cpu().numpy().astype(np.uint8)
)
self.mask = torch.max(prediction, dim=0).indices.cpu().numpy().astype(np.uint8)
self.visualization = overlay_davis(self.image.copy(), self.mask)
if self.with_bbox and len(np.unique(self.mask)) > 1:
masks = []
Expand Down
11 changes: 3 additions & 8 deletions node_scripts/deva_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,12 @@ def callback(self, img_msg):
self.classes,
min_size,
)
prob = self.deva_predictor.incorporate_detection(
deva_input, incorporate_mask, segments_info
)
prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info)
self.object_ids = [seg.id for seg in segments_info]
self.category_ids = [seg.category_ids[0] for seg in segments_info]
self.scores = [seg.scores[0] for seg in segments_info]
self.labels_with_scores = [
f"{self.classes[seg.category_ids[0]]} {seg.scores[0]:.2f}"
for seg in segments_info
f"{self.classes[seg.category_ids[0]]} {seg.scores[0]:.2f}" for seg in segments_info
]
self.cnt = 1
else:
Expand Down Expand Up @@ -121,9 +118,7 @@ def callback(self, img_msg):

label_names = [self.classes[cls_id] for cls_id in detections.class_id]
label_array = LabelArray()
label_array.labels = [
Label(id=i + 1, name=name) for i, name in enumerate(label_names)
]
label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
label_array.header.stamp = rospy.Time.now()
label_array.header.frame_id = img_msg.header.frame_id
class_result = ClassificationResult(
Expand Down
13 changes: 3 additions & 10 deletions node_scripts/grounding_dino_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def config_cb(self, config, level):
def publish_result(self, boxes, label_names, scores, vis, frame_id):
if label_names is not None:
label_array = LabelArray()
label_array.labels = [
Label(id=i + 1, name=name) for i, name in enumerate(label_names)
]
label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
label_array.header.stamp = rospy.Time.now()
label_array.header.frame_id = frame_id
self.pub_labels.publish(label_array)
Expand Down Expand Up @@ -115,18 +113,13 @@ def callback(self, img_msg):

labels = [self.classes[cls_id] for cls_id in detections.class_id]
scores = detections.confidence.tolist()
labels_with_scores = [
f"{label} {score:.2f}" for label, score in zip(labels, scores)
]
labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

box_annotator = sv.BoxAnnotator()
self.visualization = box_annotator.annotate(
scene=self.image.copy(), detections=detections, labels=labels_with_scores
)
self.publish_result(
detections.xyxy, labels, scores, self.visualization, img_msg.header.frame_id
)

self.publish_result(detections.xyxy, labels, scores, self.visualization, img_msg.header.frame_id)


if __name__ == "__main__":
Expand Down
26 changes: 6 additions & 20 deletions node_scripts/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import rospkg
import torch

CKECKPOINT_ROOT = os.path.join(
rospkg.RosPack().get_path("tracking_ros"), "trained_data"
)
CKECKPOINT_ROOT = os.path.join(rospkg.RosPack().get_path("tracking_ros"), "trained_data")


@dataclass
Expand Down Expand Up @@ -68,20 +66,12 @@ def get_predictor(self):
SamAutomaticMaskGenerator,
SamPredictor,
)
model = sam_model_registry[self.model_type[:5]](
checkpoint=self.model_checkpoints[self.model_type]
)
model = sam_model_registry[self.model_type[:5]](checkpoint=self.model_checkpoints[self.model_type])
model.to(device=self.device).eval()
return (
SamPredictor(model)
if self.mode == "prompt"
else SamAutomaticMaskGenerator(model)
)
return SamPredictor(model) if self.mode == "prompt" else SamAutomaticMaskGenerator(model)

@classmethod
def from_args(
cls, model_type: str = "vit_t", mode: str = "prompt", device: str = "cuda:0"
):
def from_args(cls, model_type: str = "vit_t", mode: str = "prompt", device: str = "cuda:0"):
return cls(model_name="SAM", model_type=model_type, mode=mode, device=device)

@classmethod
Expand Down Expand Up @@ -193,12 +183,8 @@ def from_rosparam(cls):

@dataclass
class GroundingDINOConfig(ROSInferenceModelConfig):
model_config = os.path.join(
CKECKPOINT_ROOT, "groundingdino/GroundingDINO_SwinT_OGC.py"
)
model_checkpoint = os.path.join(
CKECKPOINT_ROOT, "groundingdino/groundingdino_swint_ogc.pth"
)
model_config = os.path.join(CKECKPOINT_ROOT, "groundingdino/GroundingDINO_SwinT_OGC.py")
model_checkpoint = os.path.join(CKECKPOINT_ROOT, "groundingdino/groundingdino_swint_ogc.pth")

def get_predictor(self):
try:
Expand Down
31 changes: 7 additions & 24 deletions node_scripts/sam_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ def __init__(self):
) # number of masks to generate automatically, in order of score

self.bridge = CvBridge()
self.pub_segmentation_img = self.advertise(
"~output/segmentation", Image, queue_size=1
)
self.pub_vis_img = self.advertise(
"~output/segmentation_image", Image, queue_size=1
)
self.pub_segmentation_img = self.advertise("~output/segmentation", Image, queue_size=1)
self.pub_vis_img = self.advertise("~output/segmentation_image", Image, queue_size=1)

# initialize prompt
self.publish_mask = False # trigger to publish mask
Expand Down Expand Up @@ -166,12 +162,7 @@ def point_callback(self, point_msg):
point_x = int(point_msg.point.x) # x is within 0 ~ width
point_y = int(point_msg.point.y) # y is within 0 ~ height

if (
point_x < 1
or point_x > (self.image.shape[1] - 1)
or point_y < 1
or point_y > (self.image.shape[0] - 1)
):
if point_x < 1 or point_x > (self.image.shape[1] - 1) or point_y < 1 or point_y > (self.image.shape[0] - 1):
rospy.logwarn("point {} is out of image shape".format([point_x, point_y]))
return

Expand Down Expand Up @@ -274,18 +265,12 @@ def callback(self, img_msg):
if self.prompt_mask is not None: # if prompt mask exists
paint_mask = self.mask.copy()
paint_mask[self.prompt_mask] = self.num_mask + 1
self.visualization = overlay_davis(
self.visualization, paint_mask
)
self.visualization = overlay_davis(self.visualization, paint_mask)
else: # if prompt mask does not exist
self.visualization = overlay_davis(
self.visualization, self.mask
)
self.visualization = overlay_davis(self.visualization, self.mask)
else: # when initial state
if self.prompt_mask is not None: # if prompt mask exists
self.visualization = overlay_davis(
self.visualization, self.prompt_mask.astype(np.uint8)
)
self.visualization = overlay_davis(self.visualization, self.prompt_mask.astype(np.uint8))
self.visualization = draw_prompt(
self.visualization,
self.prompt_points,
Expand All @@ -299,9 +284,7 @@ def callback(self, img_msg):
masks = self.predictor.generate(self.image) # dict of masks
self.masks = [mask["segmentation"] for mask in masks] # list of [H, W]
if self.num_slots > 0:
self.masks = [mask["segmentation"] for mask in masks][
: self.num_slots
]
self.masks = [mask["segmentation"] for mask in masks][: self.num_slots]
self.mask = np.zeros_like(self.masks[0]).astype(np.uint8) # [H, W]
for i, mask in enumerate(self.masks):
self.mask = np.clip(
Expand Down
8 changes: 2 additions & 6 deletions node_scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def draw_bbox(image: np.ndarray, bbox, label) -> np.ndarray:
mask=None,
)
box_annotator = sv.BoxAnnotator()
result_image = box_annotator.annotate(
scene=image, detections=detections, labels=[label]
)
result_image = box_annotator.annotate(scene=image, detections=detections, labels=[label])
return result_image.astype(np.uint8)


Expand All @@ -43,9 +41,7 @@ def draw_prompt(image: np.ndarray, points, prompt_labels, bbox, label) -> np.nda
return result_image.astype(np.uint8)


def overlay_davis(
image: np.ndarray, mask: np.ndarray, alpha: float = 0.5, fade: bool = False
):
def overlay_davis(image: np.ndarray, mask: np.ndarray, alpha: float = 0.5, fade: bool = False):
"""Overlay segmentation on top of RGB image. from davis official"""
im_overlay = image.copy()

Expand Down
39 changes: 39 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Shared configuration for the formatters and checkers run on this repository.

# flake8 options; CI invokes flake8 through the `pflake8` wrapper.
[tool.flake8]
exclude = "Cutie"
ignore = "H,E501,W503,F841"

# isort follows black's import layout and uses the same line length as black below.
[tool.isort]
profile = "black"
line_length = 120

[tool.black]
line-length = 120
target-version = ["py38", "py39"]
exclude = "Cutie"

[tool.mypy]
# Kept as a quoted string: a bare TOML float would turn a future "3.10" into 3.1.
python_version = "3.9"
exclude = "Cutie"
show_error_codes = true
warn_unused_ignores = false
check_untyped_defs = true

# Modules listed here are imported by the code but get ignore_missing_imports,
# so mypy does not report them when stubs or sources are unavailable.
[[tool.mypy.overrides]]
module = [
    "rospy",
    "rostest",
    "cv_bridge",
    "dynamic_reconfigure.*",
    "tracking_ros.cfg",
    "sensor_msgs.msg",
    "jsk_recognition_msgs.msg",
    "jsk_topic_tools",
    "jsk_data",
    "torchvision",
    "supervision",
    "deva.*",
    "cutie.*",
    "groundingdino"
]
ignore_missing_imports = true

Loading

0 comments on commit c30b3e0

Please sign in to comment.