diff --git a/node_scripts/deva_node.py b/node_scripts/deva_node.py index 66c2908..6b59966 100644 --- a/node_scripts/deva_node.py +++ b/node_scripts/deva_node.py @@ -3,6 +3,7 @@ import numpy as np import torch +import torch.nn.functional as F import supervision as sv import rospy @@ -14,7 +15,7 @@ from jsk_recognition_msgs.msg import ClassificationResult from jsk_recognition_msgs.msg import Label, LabelArray -from deva.inference.demo_utils import get_input_frame_for_deva +from deva.dataset.utils import im_normalization from deva.ext.grounding_dino import segment_with_text from model_config import SAMConfig, GroundingDINOConfig, DEVAConfig @@ -73,8 +74,9 @@ def publish_result(self, mask, vis, frame_id): def callback(self, img_msg): self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8") with torch.cuda.amp.autocast(enabled=self.cfg["amp"]): - min_size = min(self.image.shape[:2]) - deva_input = get_input_frame_for_deva(self.image, min_size).to(self.deva_config.device) + h, w = self.image.shape[:2] + deva_input = im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255).unsqueeze(0).to(self.deva_config.device) + deva_input = F.interpolate(deva_input, (h, w), mode='bilinear', align_corners=False)[0] if self.cnt % self.cfg["detection_every"] == 0: incorporate_mask, segments_info = segment_with_text( self.cfg, @@ -82,7 +84,7 @@ def callback(self, img_msg): self.sam_predictor, self.image, self.classes, - min_size, + min(h, w), ) prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info) self.object_ids = [seg.id for seg in segments_info]