From c30b3e0baadebf913a48bf684d44cac7d86d76d3 Mon Sep 17 00:00:00 2001
From: Jihoon Oh
Date: Sat, 13 Jan 2024 20:31:39 +0900
Subject: [PATCH] apply black formatting with 120-char lines, add pyproject.toml, and update the CI formatter/checker workflow

---
 .github/workflows/peripheral.yml    | 14 ++---
 node_scripts/cutie_node.py          | 24 +++------
 node_scripts/deva_node.py           | 11 ++--
 node_scripts/grounding_dino_node.py | 13 ++---
 node_scripts/model_config.py        | 26 +++-------
 node_scripts/sam_node.py            | 31 +++--------
 node_scripts/utils.py               |  8 +--
 pyproject.toml                      | 39 ++++++++++++++
 scripts/install_trained_data.py     | 79 ++++++++++++++---------------
 test/test_node.py                   | 31 +++++------
 10 files changed, 125 insertions(+), 151 deletions(-)
 create mode 100644 pyproject.toml

diff --git a/.github/workflows/peripheral.yml b/.github/workflows/peripheral.yml
index b63cf6e..056db90 100644
--- a/.github/workflows/peripheral.yml
+++ b/.github/workflows/peripheral.yml
@@ -1,4 +1,4 @@
-name: Peripheral test
+name: formatter and checker for python
 
 on:
   push:
@@ -14,10 +14,10 @@ jobs:
     steps:
       - name: Checkout Code
         uses: actions/checkout@v2
-      - name: pip install formatters and mypy
+      - name: pip install formatters and checkers
         run: |
-          pip3 install mypy flake8 isort
-      - name: pip install torch and numpy for build
+          pip3 install mypy isort black pyproject-flake8
+      - name: pip install torch and numpy for module build
         run: |
           pip3 install numpy>=1.24 torch>=2.0
@@ -27,7 +27,7 @@
           mypy --version
           mypy .
 
-      - name: check by isrot and flake8
+      - name: reformat and check
         run: |
-          python3 -m isort example/ test/ node_script/
-          python3 -m flake8 example/ test/ node_script/
+          python3 -m black .
+          python3 -m pflake8 node_scripts/ scripts/ test/
diff --git a/node_scripts/cutie_node.py b/node_scripts/cutie_node.py
index eb6b66d..4be4ab8 100644
--- a/node_scripts/cutie_node.py
+++ b/node_scripts/cutie_node.py
@@ -14,9 +14,7 @@
 from utils import overlay_davis
 
 
-class CutieNode(
-    object
-):  # should not be ConnectionBasedNode cause Cutie tracker needs continuous input
+class CutieNode(object):  # should not be ConnectionBasedNode cause Cutie tracker needs continuous input
     def __init__(self):
         super(CutieNode, self).__init__()
         self.cutie_config = CutieConfig.from_rosparam()
@@ -33,12 +31,8 @@ def __init__(self):
             queue_size=1,
             buff_size=2**24,
         )
-        self.pub_vis_img = rospy.Publisher(
-            "~output/segmentation_image", Image, queue_size=1
-        )
-        self.pub_segmentation_img = rospy.Publisher(
-            "~output/segmentation", Image, queue_size=1
-        )
+        self.pub_vis_img = rospy.Publisher("~output/segmentation_image", Image, queue_size=1)
+        self.pub_segmentation_img = rospy.Publisher("~output/segmentation", Image, queue_size=1)
 
     @torch.inference_mode()
     def initialize(self):
@@ -51,9 +45,7 @@ def initialize(self):
         # initialize the model with the mask
         with torch.cuda.amp.autocast(enabled=True):
             image_torch = (
-                torch.from_numpy(self.image.transpose(2, 0, 1))
-                .float()
-                .to(self.cutie_config.device, non_blocking=True)
+                torch.from_numpy(self.image.transpose(2, 0, 1)).float().to(self.cutie_config.device, non_blocking=True)
                 / 255
             )
             # initialize with the mask
@@ -86,15 +78,11 @@ def callback(self, img_msg):
         self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
         with torch.cuda.amp.autocast(enabled=True):
             image_torch = (
-                torch.from_numpy(self.image.transpose(2, 0, 1))
-                .float()
-                .to(self.cutie_config.device, non_blocking=True)
+                torch.from_numpy(self.image.transpose(2, 0, 1)).float().to(self.cutie_config.device, non_blocking=True)
                 / 255
             )
             prediction = self.predictor.step(image_torch)
-            self.mask = (
-                torch.max(prediction, dim=0).indices.cpu().numpy().astype(np.uint8)
-            )
+            self.mask = torch.max(prediction, dim=0).indices.cpu().numpy().astype(np.uint8)
             self.visualization = overlay_davis(self.image.copy(), self.mask)
             if self.with_bbox and len(np.unique(self.mask)) > 1:
                 masks = []
diff --git a/node_scripts/deva_node.py b/node_scripts/deva_node.py
index 72f9e38..890e602 100644
--- a/node_scripts/deva_node.py
+++ b/node_scripts/deva_node.py
@@ -84,15 +84,12 @@ def callback(self, img_msg):
                     self.classes,
                     min_size,
                 )
-                prob = self.deva_predictor.incorporate_detection(
-                    deva_input, incorporate_mask, segments_info
-                )
+                prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info)
                 self.object_ids = [seg.id for seg in segments_info]
                 self.category_ids = [seg.category_ids[0] for seg in segments_info]
                 self.scores = [seg.scores[0] for seg in segments_info]
                 self.labels_with_scores = [
-                    f"{self.classes[seg.category_ids[0]]} {seg.scores[0]:.2f}"
-                    for seg in segments_info
+                    f"{self.classes[seg.category_ids[0]]} {seg.scores[0]:.2f}" for seg in segments_info
                 ]
                 self.cnt = 1
             else:
@@ -121,9 +118,7 @@ def callback(self, img_msg):
 
         label_names = [self.classes[cls_id] for cls_id in detections.class_id]
         label_array = LabelArray()
-        label_array.labels = [
-            Label(id=i + 1, name=name) for i, name in enumerate(label_names)
-        ]
+        label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
         label_array.header.stamp = rospy.Time.now()
         label_array.header.frame_id = img_msg.header.frame_id
         class_result = ClassificationResult(
diff --git a/node_scripts/grounding_dino_node.py b/node_scripts/grounding_dino_node.py
index 1233036..93bd209 100644
--- a/node_scripts/grounding_dino_node.py
+++ b/node_scripts/grounding_dino_node.py
@@ -53,9 +53,7 @@ def config_cb(self, config, level):
     def publish_result(self, boxes, label_names, scores, vis, frame_id):
         if label_names is not None:
             label_array = LabelArray()
-            label_array.labels = [
-                Label(id=i + 1, name=name) for i, name in enumerate(label_names)
-            ]
+            label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
             label_array.header.stamp = rospy.Time.now()
             label_array.header.frame_id = frame_id
             self.pub_labels.publish(label_array)
@@ -115,18 +113,13 @@ def callback(self, img_msg):
 
         labels = [self.classes[cls_id] for cls_id in detections.class_id]
         scores = detections.confidence.tolist()
-        labels_with_scores = [
-            f"{label} {score:.2f}" for label, score in zip(labels, scores)
-        ]
+        labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]
         box_annotator = sv.BoxAnnotator()
         self.visualization = box_annotator.annotate(
             scene=self.image.copy(), detections=detections, labels=labels_with_scores
         )
 
-        self.publish_result(
-            detections.xyxy, labels, scores, self.visualization, img_msg.header.frame_id
-        )
-
+        self.publish_result(detections.xyxy, labels, scores, self.visualization, img_msg.header.frame_id)
 
 
 if __name__ == "__main__":
diff --git a/node_scripts/model_config.py b/node_scripts/model_config.py
index 6fd1863..6a1c360 100644
--- a/node_scripts/model_config.py
+++ b/node_scripts/model_config.py
@@ -5,9 +5,7 @@
 import rospkg
 import torch
 
-CKECKPOINT_ROOT = os.path.join(
-    rospkg.RosPack().get_path("tracking_ros"), "trained_data"
-)
+CKECKPOINT_ROOT = os.path.join(rospkg.RosPack().get_path("tracking_ros"), "trained_data")
 
 
 @dataclass
@@ -68,20 +66,12 @@ def get_predictor(self):
                 SamAutomaticMaskGenerator,
                 SamPredictor,
             )
-        model = sam_model_registry[self.model_type[:5]](
-            checkpoint=self.model_checkpoints[self.model_type]
-        )
+        model = sam_model_registry[self.model_type[:5]](checkpoint=self.model_checkpoints[self.model_type])
         model.to(device=self.device).eval()
-        return (
-            SamPredictor(model)
-            if self.mode == "prompt"
-            else SamAutomaticMaskGenerator(model)
-        )
+        return SamPredictor(model) if self.mode == "prompt" else SamAutomaticMaskGenerator(model)
 
     @classmethod
-    def from_args(
-        cls, model_type: str = "vit_t", mode: str = "prompt", device: str = "cuda:0"
-    ):
+    def from_args(cls, model_type: str = "vit_t", mode: str = "prompt", device: str = "cuda:0"):
         return cls(model_name="SAM", model_type=model_type, mode=mode, device=device)
 
     @classmethod
@@ -193,12 +183,8 @@ def from_rosparam(cls):
 
 @dataclass
 class GroundingDINOConfig(ROSInferenceModelConfig):
-    model_config = os.path.join(
-        CKECKPOINT_ROOT, "groundingdino/GroundingDINO_SwinT_OGC.py"
-    )
-    model_checkpoint = os.path.join(
-        CKECKPOINT_ROOT, "groundingdino/groundingdino_swint_ogc.pth"
-    )
+    model_config = os.path.join(CKECKPOINT_ROOT, "groundingdino/GroundingDINO_SwinT_OGC.py")
+    model_checkpoint = os.path.join(CKECKPOINT_ROOT, "groundingdino/groundingdino_swint_ogc.pth")
 
     def get_predictor(self):
         try:
diff --git a/node_scripts/sam_node.py b/node_scripts/sam_node.py
index e8871c7..73307da 100644
--- a/node_scripts/sam_node.py
+++ b/node_scripts/sam_node.py
@@ -56,12 +56,8 @@ def __init__(self):
         )  # number of masks to generate automatically, in order of score
         self.bridge = CvBridge()
 
-        self.pub_segmentation_img = self.advertise(
-            "~output/segmentation", Image, queue_size=1
-        )
-        self.pub_vis_img = self.advertise(
-            "~output/segmentation_image", Image, queue_size=1
-        )
+        self.pub_segmentation_img = self.advertise("~output/segmentation", Image, queue_size=1)
+        self.pub_vis_img = self.advertise("~output/segmentation_image", Image, queue_size=1)
 
         # initilize prompt
         self.publish_mask = False  # trigger to publish mask
@@ -166,12 +162,7 @@ def point_callback(self, point_msg):
         point_x = int(point_msg.point.x)  # x is within 0 ~ width
         point_y = int(point_msg.point.y)  # y is within 0 ~ height
 
-        if (
-            point_x < 1
-            or point_x > (self.image.shape[1] - 1)
-            or point_y < 1
-            or point_y > (self.image.shape[0] - 1)
-        ):
+        if point_x < 1 or point_x > (self.image.shape[1] - 1) or point_y < 1 or point_y > (self.image.shape[0] - 1):
             rospy.logwarn("point {} is out of image shape".format([point_x, point_y]))
             return
 
@@ -274,18 +265,12 @@ def callback(self, img_msg):
                 if self.prompt_mask is not None:  # if prompt mask exists
                     paint_mask = self.mask.copy()
                     paint_mask[self.prompt_mask] = self.num_mask + 1
-                    self.visualization = overlay_davis(
-                        self.visualization, paint_mask
-                    )
+                    self.visualization = overlay_davis(self.visualization, paint_mask)
                 else:  # if prompt mask does not exist
-                    self.visualization = overlay_davis(
-                        self.visualization, self.mask
-                    )
+                    self.visualization = overlay_davis(self.visualization, self.mask)
             else:  # when initial state
                 if self.prompt_mask is not None:  # if prompt mask exists
-                    self.visualization = overlay_davis(
-                        self.visualization, self.prompt_mask.astype(np.uint8)
-                    )
+                    self.visualization = overlay_davis(self.visualization, self.prompt_mask.astype(np.uint8))
                 self.visualization = draw_prompt(
                     self.visualization,
                     self.prompt_points,
@@ -299,9 +284,7 @@ def callback(self, img_msg):
             masks = self.predictor.generate(self.image)  # dict of masks
             self.masks = [mask["segmentation"] for mask in masks]  # list of [H, W]
             if self.num_slots > 0:
-                self.masks = [mask["segmentation"] for mask in masks][
-                    : self.num_slots
-                ]
+                self.masks = [mask["segmentation"] for mask in masks][: self.num_slots]
             self.mask = np.zeros_like(self.masks[0]).astype(np.uint8)  # [H, W]
             for i, mask in enumerate(self.masks):
                 self.mask = np.clip(
diff --git a/node_scripts/utils.py b/node_scripts/utils.py
index dc97123..704a305 100644
--- a/node_scripts/utils.py
+++ b/node_scripts/utils.py
@@ -31,9 +31,7 @@ def draw_bbox(image: np.ndarray, bbox, label) -> np.ndarray:
         mask=None,
     )
     box_annotator = sv.BoxAnnotator()
-    result_image = box_annotator.annotate(
-        scene=image, detections=detections, labels=[label]
-    )
+    result_image = box_annotator.annotate(scene=image, detections=detections, labels=[label])
     return result_image.astype(np.uint8)
 
 
@@ -43,9 +41,7 @@ def draw_prompt(image: np.ndarray, points, prompt_labels, bbox, label) -> np.nda
     return result_image.astype(np.uint8)
 
 
-def overlay_davis(
-    image: np.ndarray, mask: np.ndarray, alpha: float = 0.5, fade: bool = False
-):
+def overlay_davis(image: np.ndarray, mask: np.ndarray, alpha: float = 0.5, fade: bool = False):
     """Overlay segmentation on top of RGB image. from davis official"""
     im_overlay = image.copy()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..38b1609
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,39 @@
+[tool.flake8]
+exclude = "Cutie"
+ignore = "H,E501,W503,F841"
+
+[tool.isort]
+profile = "black"
+line_length = 120
+
+[tool.black]
+line-length = 120
+target-version = ["py38", "py39"]
+exclude = "Cutie"
+
+[tool.mypy]
+python_version = 3.9
+exclude = "Cutie"
+show_error_codes = true
+warn_unused_ignores = false
+check_untyped_defs = true
+
+[[tool.mypy.overrides]]
+module = [
+    "rospy",
+    "rostest",
+    "cv_bridge",
+    "dynamic_reconfigure.*",
+    "tracking_ros.cfg",
+    "sensor_msgs.msg",
+    "jsk_recognition_msgs.msg",
+    "jsk_topic_tools",
+    "jsk_data",
+    "torchvision",
+    "supervision",
+    "deva.*",
+    "cutie.*",
+    "groundingdino"
+]
+ignore_missing_imports = true
+
diff --git a/scripts/install_trained_data.py b/scripts/install_trained_data.py
index 1ab0d1c..1bda57a 100644
--- a/scripts/install_trained_data.py
+++ b/scripts/install_trained_data.py
@@ -5,99 +5,98 @@
 import jsk_data
 
+
 def download_data(*args, **kwargs):
-    p = multiprocessing.Process(
-        target=jsk_data.download_data,
-        args=args,
-        kwargs=kwargs)
+    p = multiprocessing.Process(target=jsk_data.download_data, args=args, kwargs=kwargs)
     p.start()
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-v', '--verbose', dest='quiet', action='store_false')
+    parser.add_argument("-v", "--verbose", dest="quiet", action="store_false")
     args = parser.parse_args()
     args.quiet
 
-    PKG = 'tracking_ros'
+    PKG = "tracking_ros"
 
     # segment anything
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_b.pth',
-        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-        md5='01ec64d29a2fca3f0661936605ae66f8',
+        path="trained_data/sam/sam_vit_b.pth",
+        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+        md5="01ec64d29a2fca3f0661936605ae66f8",
     )
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_l.pth',
-        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-        md5='0b3195507c641ddb6910d2bb5adee89c',
+        path="trained_data/sam/sam_vit_l.pth",
+        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+        md5="0b3195507c641ddb6910d2bb5adee89c",
     )
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_h.pth',
-        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
-        md5='4b8939a88964f0f4ff5f5b2642c598a6',
+        path="trained_data/sam/sam_vit_h.pth",
+        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+        md5="4b8939a88964f0f4ff5f5b2642c598a6",
     )
 
     # segment anything hq
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_b_hq.pth',
-        url='https://drive.google.com/uc?id=11yExZLOve38kRZPfRx_MRxfIAKmfMY47',
-        md5='c6b8953247bcfdc8bb8ef91e36a6cacc',
+        path="trained_data/sam/sam_vit_b_hq.pth",
+        url="https://drive.google.com/uc?id=11yExZLOve38kRZPfRx_MRxfIAKmfMY47",
+        md5="c6b8953247bcfdc8bb8ef91e36a6cacc",
     )
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_l_hq.pth',
-        url='https://drive.google.com/uc?id=1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G',
-        md5='08947267966e4264fb39523eccc33f86',
+        path="trained_data/sam/sam_vit_l_hq.pth",
+        url="https://drive.google.com/uc?id=1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G",
+        md5="08947267966e4264fb39523eccc33f86",
    )
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/sam_vit_h_hq.pth',
-        url='https://drive.google.com/uc?id=1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8',
-        md5='3560f6b6a5a6edacd814a1325c39640a',
+        path="trained_data/sam/sam_vit_h_hq.pth",
+        url="https://drive.google.com/uc?id=1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8",
+        md5="3560f6b6a5a6edacd814a1325c39640a",
     )
 
     # mobile sam
     download_data(
         pkg_name=PKG,
-        path='trained_data/sam/mobile_sam.pth',
-        url='https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/mobile_sam.pt',
-        md5='f3c0d8cda613564d499310dab6c812cd',
+        path="trained_data/sam/mobile_sam.pth",
+        url="https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/mobile_sam.pt",
+        md5="f3c0d8cda613564d499310dab6c812cd",
     )
 
     # cutie
     download_data(
         pkg_name=PKG,
-        path='trained_data/cutie/cutie-base-mega.pth',
-        url='https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth',
-        md5='a6071de6136982e396851903ab4c083a',
+        path="trained_data/cutie/cutie-base-mega.pth",
+        url="https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth",
+        md5="a6071de6136982e396851903ab4c083a",
     )
 
     # DEVA
     download_data(
         pkg_name=PKG,
-        path='trained_data/deva/DEVA-propagation.pth',
-        url='https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/DEVA-propagation.pth',
-        md5='a614cc9737a5b4c22ecbdc93e7842ecb',
+        path="trained_data/deva/DEVA-propagation.pth",
+        url="https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/DEVA-propagation.pth",
+        md5="a614cc9737a5b4c22ecbdc93e7842ecb",
     )
 
     # grounding dino
     download_data(
         pkg_name=PKG,
-        path='trained_data/groundingdino/groundingdino_swint_ogc.pth',
-        url='https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth',
-        md5='075ebfa7242d913f38cb051fe1a128a2',
+        path="trained_data/groundingdino/groundingdino_swint_ogc.pth",
+        url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+        md5="075ebfa7242d913f38cb051fe1a128a2",
     )
     download_data(
         pkg_name=PKG,
-        path='trained_data/groundingdino/GroundingDINO_SwinT_OGC.py',
-        url='https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/GroundingDINO_SwinT_OGC.py',
-        md5='bdb07fc17b611d622633d133d2cf873a',
+        path="trained_data/groundingdino/GroundingDINO_SwinT_OGC.py",
+        url="https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/GroundingDINO_SwinT_OGC.py",
+        md5="bdb07fc17b611d622633d133d2cf873a",
     )
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
diff --git a/test/test_node.py b/test/test_node.py
index 320db23..ffdbeac 100644
--- a/test/test_node.py
+++ b/test/test_node.py
@@ -7,15 +7,13 @@
 import rostest
 from sensor_msgs.msg import Image
 
-from detic_ros.msg import SegmentationInfo
-
 
 class TestNode(unittest.TestCase):
     def setUp(self):
         pass
 
     def test_mytest2(self):
-        cb_data = {'msg1': None, 'msg2': None, 'msg3': None, 'msg4': None, 'msg5': None}
+        cb_data = {"msg1": None, "msg2": None, "msg3": None, "msg4": None, "msg5": None}
 
         def all_subscribed() -> bool:
             rospy.logwarn("received: {}".format([not not v for v in cb_data.values()]))
@@ -23,25 +21,24 @@ def all_subscribed() -> bool:
             return all(bool_list)
 
         def cb_debug_image(msg):
-            cb_data['msg1'] = msg
+            cb_data["msg1"] = msg
 
         def cb_debug_segimage(msg):
-            cb_data['msg2'] = msg
+            cb_data["msg2"] = msg
 
         def cb_info(msg):
-            cb_data['msg3'] = msg
+            cb_data["msg3"] = msg
 
         def cb_test_image(msg):
-            cb_data['msg4'] = msg
+            cb_data["msg4"] = msg
 
         def cb_test_image_filter(msg):
-            cb_data['msg5'] = msg
+            cb_data["msg5"] = msg
 
-        rospy.Subscriber('/docker/detic_segmentor/debug_image', Image, cb_debug_image)
-        rospy.Subscriber('/docker/detic_segmentor/debug_segmentation_image', Image, cb_debug_segimage)
-        rospy.Subscriber('/docker/detic_segmentor/segmentation_info', SegmentationInfo, cb_info)
-        rospy.Subscriber('/test_out_image', Image, cb_test_image)
-        rospy.Subscriber('/test_out_image_filter', Image, cb_test_image_filter)
+        rospy.Subscriber("/docker/detic_segmentor/debug_image", Image, cb_debug_image)
+        rospy.Subscriber("/docker/detic_segmentor/debug_segmentation_image", Image, cb_debug_segimage)
+        rospy.Subscriber("/test_out_image", Image, cb_test_image)
+        rospy.Subscriber("/test_out_image_filter", Image, cb_test_image_filter)
 
         time_out = 40
         for _ in range(time_out):
@@ -51,8 +48,6 @@ def cb_test_image_filter(msg):
         assert False
 
 
-if __name__ == '__main__':
-    rospy.init_node('test_sample')
-    rostest.rosrun('detic_ros',
-                   'test_node',
-                   TestNode)
+if __name__ == "__main__":
+    rospy.init_node("test_sample")
+    rostest.rosrun("detic_ros", "test_node", TestNode)
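
A minimal sketch of reproducing the updated CI checks locally, assuming the same tool set the workflow installs (black, isort, pyproject-flake8 and mypy) and that the commands are run from the repository root; the isort invocation and its target directories are an assumption, since the new workflow installs isort but only runs black and pflake8:

    pip3 install mypy isort black pyproject-flake8
    python3 -m black .
    python3 -m pflake8 node_scripts/ scripts/ test/
    python3 -m isort node_scripts/ scripts/ test/
    mypy .

black and pflake8 mirror the "reformat and check" step, mypy mirrors the mypy step, and all four tools read their settings from the pyproject.toml added by this patch.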