diff --git a/butia_recognition/scripts/butia_recognition/paligemma_recognition/README.md b/butia_recognition/scripts/butia_recognition/paligemma_recognition/README.md index 92ce27f..73c7ca7 100644 --- a/butia_recognition/scripts/butia_recognition/paligemma_recognition/README.md +++ b/butia_recognition/scripts/butia_recognition/paligemma_recognition/README.md @@ -3,5 +3,5 @@ Install [this branch](https://github.com/butia-bots/butia_vision_msgs/tree/feature/gpsr-recognition) of butia_vision_msgs. Then run the following commands on the jetson, and make sure the pre-installed version of pytorch, numpy and other libraries from JetPack SDK is kept frozen and not updated during the install process. ```sh -pip install inference supervision transformers accelerate peft +pip install transformers accelerate peft bitsandbytes supervision ``` \ No newline at end of file diff --git a/butia_recognition/scripts/butia_recognition/paligemma_recognition/paligemma_recognition.py b/butia_recognition/scripts/butia_recognition/paligemma_recognition/paligemma_recognition.py index 85e9910..c7837f9 100755 --- a/butia_recognition/scripts/butia_recognition/paligemma_recognition/paligemma_recognition.py +++ b/butia_recognition/scripts/butia_recognition/paligemma_recognition/paligemma_recognition.py @@ -7,8 +7,8 @@ import os from copy import copy import cv2 -from inference.models.paligemma import PaliGemma -from inference.models.sam import SegmentAnything +from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor +from transformers import SamModel, SamProcessor from std_msgs.msg import Header from sensor_msgs.msg import Image from geometry_msgs.msg import Vector3 @@ -29,6 +29,8 @@ def __init__(self, state=True): self.colors = dict([(k, np.random.randint(low=0, high=256, size=(3,)).tolist()) for k in self.classes]) + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.loadModel() self.initRosComm() @@ -45,9 +47,9 @@ def serverSetClass(self, req): return SetClassResponse() def serverVisualQuestionAnswering(self, req): - result = self.model.predict(image_in=self.cv_img, prompt=res.question) + result = self.inferPaliGemma(image=PIL.Image.fromarray(cv2.cvtColor(self.cv_img, cv2.COLOR_BGR2RGB)), prompt=req.question) res = VisualQuestionAnsweringResponse() - res.answer = result[0] + res.answer = result return res def serverStart(self, req): @@ -59,16 +61,33 @@ def serverStop(self, req): return super().serverStop(req) def loadModel(self): - self.model = PaliGemma(model_id='paligemma-3b-mix-224') - self.sam = SegmentAnything() + self.pg = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224').to(self.device) + self.pg_processor = PaliGemmaProcessor.from_pretrained('google/paligemma-3b-mix-224') + self.sam = SamModel.from_pretrained('facebook/sam-vit-base').to(self.device) + self.sam_processor = SamProcessor.from_pretrained('facebook/sam-vit-base') print('Done loading model!') def unLoadModel(self): - del self.model + del self.pg del self.sam gc.collect() torch.cuda.empty_cache() - self.model = None + self.pg = None + self.sam = None + + def inferPaliGemma(self, image, prompt): + inputs = self.pg_processor(text=prompt, images=image, return_tensors="pt").to(self.device) + with torch.inference_mode(): + outputs = self.pg.generate(**inputs, max_new_tokens=50, do_sample=False) + result = self.pg_processor.batch_decode(outputs, skip_special_tokens=True) + return result[0][len(prompt):].lstrip('\n') + + def inferSam(self, image, input_boxes): + inputs = self.sam_processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(self.device) + with torch.inference_mode(): + outputs = self.sam(**inputs) + masks = self.sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()) + return masks[0].detach().cpu().numpy() @ifState def callback(self, *args): @@ -95,17 +114,14 @@ def callback(self, *args): description_header = img_rgb.header description_header.seq = 0 - results = self.model.predict(image_in=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), prompt=f"detect " + " ; ".join(self.all_classes))[0] - boxes_ = sv.Detections.from_lmm(lmm='paligemma', result=results[0], resolution_wh=(cv_img.shape[1], cv_img.shape[0]), classes=self.all_classes) + results = self.inferPaliGemma(image=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), prompt=f"detect " + " ; ".join(self.all_classes)) + boxes_ = sv.Detections.from_lmm(lmm='paligemma', result=results, resolution_wh=(cv_img.shape[1], cv_img.shape[0]), classes=self.all_classes) debug_img = cv_img masks = [] - embeddings = self.sam.embed_image(image=cv_img)[0] for x1, y1, x2, y2 in boxes_.xyxy: - center_x = (x1 + x2)//2 - center_y = (y1 + y2)//2 - masks.append(self.sam.segment_image(image=cv_img, embeddings=embeddings, point_labels=[1], point_coords=[[center_x, center_y]])[0]) - boxes_.mask = np.array(masks) + masks.append(self.inferSam(image=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), input_boxes=[[[x1, y1, x2, y2]]])[:,0,:,:]) if len(boxes_): + boxes_.mask = np.array(masks).reshape((len(masks), cv_img.shape[0], cv_img.shape[1])) for i in range(len(boxes_)): box = boxes_[i] xyxy_box = list(boxes_[i].xyxy.astype(int)[0]) @@ -127,7 +143,7 @@ def callback(self, *args): description.bbox.center.y = int(xyxy_box[1]) + int(size[1]/2) description.bbox.size_x = size[0] description.bbox.size_y = size[1] - description.mask = ros_numpy.msgify(Image, boxes_.mask[i]) + description.mask = ros_numpy.msgify(Image, (boxes_.mask[i]*255).astype(np.uint8), encoding='mono8') if ('people' in self.all_classes and label_class in self.classes_by_category['people'] or 'people' in self.all_classes and label_class == 'people'):