
Merge branch 'feature/vlm-recognition' of https://github.com/butia-bots/butia_vision into feature/vlm-recognition
crislmfroes committed Aug 12, 2024
2 parents 99e4dc7 + eab80e0 commit ba6cf25
Showing 2 changed files with 33 additions and 17 deletions.
@@ -3,5 +3,5 @@
Install [this branch](https://github.com/butia-bots/butia_vision_msgs/tree/feature/gpsr-recognition) of butia_vision_msgs. Then run the following commands on the Jetson, making sure the pre-installed versions of PyTorch, NumPy, and the other JetPack SDK libraries stay frozen and are not upgraded during the install:

```sh
-pip install inference supervision transformers accelerate peft
+pip install transformers accelerate peft bitsandbytes supervision
```
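To keep the JetPack-provided wheels from being upgraded, one option is a pip constraints file; a minimal sketch, assuming a stock JetPack image (the grep pattern is illustrative, widen it to whatever your image actually ships):

```sh
# Record the versions JetPack preinstalled (torch, torchvision, numpy, ...)
pip freeze | grep -iE '^(torch|numpy)' > constraints.txt
# Install the new dependencies while holding those versions fixed
pip install -c constraints.txt transformers accelerate peft bitsandbytes supervision
```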
@@ -7,8 +7,8 @@
import os
from copy import copy
import cv2
-from inference.models.paligemma import PaliGemma
-from inference.models.sam import SegmentAnything
+from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+from transformers import SamModel, SamProcessor
from std_msgs.msg import Header
from sensor_msgs.msg import Image
from geometry_msgs.msg import Vector3
@@ -29,6 +29,8 @@ def __init__(self, state=True):

        self.colors = dict([(k, np.random.randint(low=0, high=256, size=(3,)).tolist()) for k in self.classes])

+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
        self.loadModel()
        self.initRosComm()

@@ -45,9 +47,9 @@ def serverSetClass(self, req):
        return SetClassResponse()

    def serverVisualQuestionAnswering(self, req):
-        result = self.model.predict(image_in=self.cv_img, prompt=res.question)
+        result = self.inferPaliGemma(image=PIL.Image.fromarray(cv2.cvtColor(self.cv_img, cv2.COLOR_BGR2RGB)), prompt=req.question)
        res = VisualQuestionAnsweringResponse()
-        res.answer = result[0]
+        res.answer = result
        return res

    def serverStart(self, req):
@@ -59,16 +61,33 @@ def serverStop(self, req):
        return super().serverStop(req)

    def loadModel(self):
-        self.model = PaliGemma(model_id='paligemma-3b-mix-224')
-        self.sam = SegmentAnything()
+        self.pg = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224').to(self.device)
+        self.pg_processor = PaliGemmaProcessor.from_pretrained('google/paligemma-3b-mix-224')
+        self.sam = SamModel.from_pretrained('facebook/sam-vit-base').to(self.device)
+        self.sam_processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
        print('Done loading model!')

    def unLoadModel(self):
-        del self.model
+        del self.pg
+        del self.sam
        gc.collect()
        torch.cuda.empty_cache()
-        self.model = None
+        self.pg = None
+        self.sam = None
+
+    def inferPaliGemma(self, image, prompt):
+        inputs = self.pg_processor(text=prompt, images=image, return_tensors="pt").to(self.device)
+        with torch.inference_mode():
+            outputs = self.pg.generate(**inputs, max_new_tokens=50, do_sample=False)
+        result = self.pg_processor.batch_decode(outputs, skip_special_tokens=True)
+        return result[0][len(prompt):].lstrip('\n')
+
+    def inferSam(self, image, input_boxes):
+        inputs = self.sam_processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(self.device)
+        with torch.inference_mode():
+            outputs = self.sam(**inputs)
+        masks = self.sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+        return masks[0].detach().cpu().numpy()

    @ifState
    def callback(self, *args):
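For context, the two new helpers are the stock transformers flow for PaliGemma generation and SAM box-prompted segmentation. A standalone sketch of the generation half, with a hypothetical image path and prompt:

```python
import PIL.Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224').to(device)
processor = PaliGemmaProcessor.from_pretrained('google/paligemma-3b-mix-224')

image = PIL.Image.open('scene.jpg')   # hypothetical test image
prompt = 'detect bottle ; cup'        # detection prompt style used by the mix checkpoint
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
decoded = processor.batch_decode(outputs, skip_special_tokens=True)[0]
print(decoded[len(prompt):].lstrip('\n'))  # strip the echoed prompt, as the node does
```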
@@ -95,17 +114,14 @@ def callback(self, *args):
        description_header = img_rgb.header
        description_header.seq = 0

-        results = self.model.predict(image_in=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), prompt=f"detect " + " ; ".join(self.all_classes))[0]
-        boxes_ = sv.Detections.from_lmm(lmm='paligemma', result=results[0], resolution_wh=(cv_img.shape[1], cv_img.shape[0]), classes=self.all_classes)
+        results = self.inferPaliGemma(image=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), prompt=f"detect " + " ; ".join(self.all_classes))
+        boxes_ = sv.Detections.from_lmm(lmm='paligemma', result=results, resolution_wh=(cv_img.shape[1], cv_img.shape[0]), classes=self.all_classes)
        debug_img = cv_img
        masks = []
-        embeddings = self.sam.embed_image(image=cv_img)[0]
        for x1, y1, x2, y2 in boxes_.xyxy:
-            center_x = (x1 + x2)//2
-            center_y = (y1 + y2)//2
-            masks.append(self.sam.segment_image(image=cv_img, embeddings=embeddings, point_labels=[1], point_coords=[[center_x, center_y]])[0])
-        boxes_.mask = np.array(masks)
+            masks.append(self.inferSam(image=PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)), input_boxes=[[[x1, y1, x2, y2]]])[:,0,:,:])
        if len(boxes_):
+            boxes_.mask = np.array(masks).reshape((len(masks), cv_img.shape[0], cv_img.shape[1]))
            for i in range(len(boxes_)):
                box = boxes_[i]
                xyxy_box = list(boxes_[i].xyxy.astype(int)[0])
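The detect prompt makes PaliGemma emit location tokens that supervision can parse into detections. A small sketch of that parsing step, with a made-up model output (the <loc....> values encode box corners on a 0-1023 grid):

```python
import supervision as sv

# Made-up PaliGemma detection output, for illustration only
result = '<loc0100><loc0200><loc0700><loc0900> bottle'
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=result,
    resolution_wh=(640, 480),   # scales the 0-1023 grid back to pixels
    classes=['bottle', 'cup'],  # maps detected class names to class ids
)
print(detections.xyxy, detections.class_id)
```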
@@ -127,7 +143,7 @@ def callback(self, *args):
                description.bbox.center.y = int(xyxy_box[1]) + int(size[1]/2)
                description.bbox.size_x = size[0]
                description.bbox.size_y = size[1]
-                description.mask = ros_numpy.msgify(Image, boxes_.mask[i])
+                description.mask = ros_numpy.msgify(Image, (boxes_.mask[i]*255).astype(np.uint8), encoding='mono8')

                if ('people' in self.all_classes and label_class in self.classes_by_category['people'] or 'people' in self.all_classes and label_class == 'people'):

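The mask fix matters because SAM returns boolean masks, while ros_numpy's Image converter expects an array whose dtype matches an explicit encoding; scaling to 0/255 uint8 with mono8 satisfies that. A minimal sketch:

```python
import numpy as np
import ros_numpy
from sensor_msgs.msg import Image

mask = np.zeros((480, 640), dtype=bool)  # stand-in for one SAM mask
msg = ros_numpy.msgify(Image, (mask * 255).astype(np.uint8), encoding='mono8')
```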
