diff --git a/butia_recognition/config/ram_network_config/.gitignore b/butia_recognition/config/ram_network_config/.gitignore new file mode 100644 index 0000000..b331dac --- /dev/null +++ b/butia_recognition/config/ram_network_config/.gitignore @@ -0,0 +1 @@ +*.pth \ No newline at end of file diff --git a/butia_recognition/scripts/butia_recognition/grounded_sam_recognition/grounded_sam_recognition.py b/butia_recognition/scripts/butia_recognition/grounded_sam_recognition/grounded_sam_recognition.py index 3c493ad..7b015a8 100755 --- a/butia_recognition/scripts/butia_recognition/grounded_sam_recognition/grounded_sam_recognition.py +++ b/butia_recognition/scripts/butia_recognition/grounded_sam_recognition/grounded_sam_recognition.py @@ -19,6 +19,10 @@ import supervision as sv from groundingdino.util.inference import Model from segment_anything import SamPredictor, sam_model_registry +from ram.models import ram +from ram import inference_ram +from ram import get_transform as get_transform_ram +from PIL import Image as PILImage torch.set_num_threads(1) @@ -51,11 +55,19 @@ def loadModel(self): sam = sam_model_registry[self.sam_model_type](checkpoint=f"{self.pkg_path}/config/sam_network_config/{self.sam_checkpoint}") self.sam_model = SamPredictor(sam) print('Done loading SAM model!') + if self.use_ram: + self.ram_model = ram(pretrained=f"{self.pkg_path}/config/ram_network_config/ram_swin_large_14m_no_optimizer.pth", vit="swin_l", image_size=384) + self.ram_model.eval() + self.ram_model = self.ram_model.to('cuda') + self.ram_transform = get_transform_ram(image_size=384) def unLoadModel(self): del self.dino_model if self.use_sam: del self.sam_model + if self.use_ram: + del self.ram_model + del self.ram_transform torch.cuda.empty_cache() @ifState @@ -74,12 +86,27 @@ def callback(self, *args): cv_img = ros_numpy.numpify(img) - results = self.dino_model.predict_with_classes(image=cv_img, classes=self.classes, box_threshold=self.box_threshold, text_threshold=self.text_threshold) + if self.use_ram: + ram_img = self.ram_transform(PILImage.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))).unsqueeze(0).to('cuda') + ram_results = inference_ram(ram_img, self.ram_model) + class_list = [class_name.strip() for class_name in ram_results[0].split('|')] + else: + class_list = self.classes + + print(class_list) + results = self.dino_model.predict_with_classes(image=cv_img, classes=class_list, box_threshold=self.box_threshold, text_threshold=self.text_threshold) results = results.with_nms(threshold=self.nms_threshold, class_agnostic=self.class_agnostic_nms) if len(results.class_id) > 0: self.sam_model.set_image(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)) box_annotator = sv.BoxAnnotator() - debug_img = box_annotator.annotate(scene=cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB), detections=results, labels=[self.classes[idx] for idx in results.class_id]) + print(results.class_id) + labels = [] + for idx in results.class_id: + if idx is not None: + labels.append(class_list[idx]) + else: + labels.append('unknown') + debug_img = box_annotator.annotate(scene=cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB), detections=results, labels=labels) mask_annotator = sv.MaskAnnotator() objects_recognition = Recognitions2D() h = Header() @@ -97,16 +124,21 @@ def callback(self, *args): description_header.seq = 0 mask_arr = [] for i in range(len(results.class_id)): + if results.class_id[i] == None: + continue class_id = int(results.class_id[i]) - if class_id >= len(self.classes): + if class_id >= len(class_list): continue - label_class = self.classes[class_id] + label_class = class_list[class_id] - max_size = [0., 0., 0.] - if class_id < len(self.max_sizes): - max_size = self.max_sizes[class_id] + if not self.use_ram: + max_size = [0., 0., 0.] + if class_id < len(self.max_sizes): + max_size = self.max_sizes[class_id] + else: + max_size = [10., 10., 10.] description = Description2D() description.header = copy(description_header) @@ -144,7 +176,7 @@ def callback(self, *args): if label_class in value: index = j j += 1 - description.label = self.classes[index] + '/' + label_class if index is not None else label_class + description.label = class_list[index] + '/' + label_class if index is not None else label_class objects_recognition.descriptions.append(description) @@ -176,6 +208,8 @@ def readParameters(self): self.sam_model_type = rospy.get_param("~sam_model_type", "vit_tiny") self.sam_hq_token_only = rospy.get_param("~sam_hq_token_only", False) + self.use_ram = rospy.get_param("~use_ram", True) + self.classes_by_category = dict(rospy.get_param("~classes_by_category", {})) self.classes = rospy.get_param("~classes", [])