Skip to content

Commit

Permalink
Add OFA detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Kanazawanaoaki committed Oct 15, 2023
1 parent 9340be1 commit cb948d9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
58 changes: 54 additions & 4 deletions jsk_perception/docker/ofa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, task, model_scale):
utils.split_paths(param_path),
arg_overrides=overrides)
elif task == "refcoco":
tasks.register_task(self.task, RefcocoTask)
tasks.register_task(task, RefcocoTask)
self.models, self.cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(param_path),
arg_overrides=overrides)
Expand Down Expand Up @@ -140,6 +140,15 @@ def encode_text(self, text, length=None, append_bos=False, append_eos=False):
s = torch.cat([s, eos_item])
return s

def convert_objects_to_text(self, text):
if len(text) == 1:
object_text = text[0]
elif len(text) >= 2:
object_text = ', '.join(text[:-1]) + f' or {text[-1]}'
else:
object_text = ''
return object_text

def construct_sample(self, image, text):
if self.task_name == "caption" or self.task_name == "vqa_gen":
patch_image = self.patch_resize_transform(image).unsqueeze(0)
Expand Down Expand Up @@ -176,7 +185,8 @@ def construct_sample(self, image, text):
h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
patch_image = self.patch_resize_transform(image).unsqueeze(0)
patch_mask = torch.tensor([True])
src_text = self.encode_text(' which region does the text " {} " describe?'.format(text), append_bos=True,
object_text = self.convert_objects_to_text(text)
src_text = self.encode_text(' which region does the text " {} " describe?'.format(object_text), append_bos=True,
append_eos=True).unsqueeze(0)
src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
sample = {
Expand Down Expand Up @@ -214,7 +224,24 @@ def infer(self, img, text):
text = result[0]['answer']
return text
elif self.task_name == "refcoco":
pass
# image = cv2.resize(img, dsize=(640, 480)) # NOTE forcely
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
# Construct input sample & preprocess for GPU if cuda available for VG
sample = self.construct_sample(image, text)
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if self.use_fp16 else sample
with torch.no_grad():
result, scores = eval_step(self.task, self.generator, self.models, sample)
results = {}
object_text = self.convert_objects_to_text(text)
for i in range(len(result)):
box = result[i]["box"]
logit = scores[i].item()
results[i] = {"box": box, "logit": logit, "phrase": object_text}

return results

# run
if __name__ == "__main__":
Expand All @@ -232,6 +259,9 @@ def infer(self, img, text):
elif ofa_task == "vqa_gen":
vqa_infer = Inference("vqa_gen", ofa_model_scale)

elif ofa_task == "detection":
detection_infer = Inference("refcoco", ofa_model_scale)

else:
raise RuntimeError("No application is available")

Expand Down Expand Up @@ -274,5 +304,25 @@ def vqa_request():
return Response(response=json.dumps({"results": results}), status=200)
except NameError:
print("Skipping create vqa_gen app")


try:
@app.route("/detection", methods=['POST'])
def detection_request():
data = request.data.decode("utf-8")
data_json = json.loads(data)
# process image
image_b = data_json['image']
image_dec = base64.b64decode(image_b)
data_np = np.fromstring(image_dec, dtype='uint8')
img = cv2.imdecode(data_np, 1)
# get text
texts = data_json['queries']
infer_results = detection_infer.infer(img, texts)
results = []
for i in range(len(infer_results)):
results.append({"id": i, "box": infer_results[i]["box"], "logit": infer_results[i]["logit"], "phrase": infer_results[i]["phrase"]})
return Response(response=json.dumps({"results": results}), status=200)
except NameError:
print("Skipping create detection app")

app.run("0.0.0.0", 8080, threaded=True)
3 changes: 2 additions & 1 deletion jsk_perception/src/jsk_perception/vil_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __init__(self):
DetectionTaskFeedback,
DetectionTaskResult,
"detection")
self.model_name = rospy.get_param("~model", default="dino")
self.pub_class = rospy.Publisher('~class', ClassificationResult, queue_size=1)
self.pub_rects = rospy.Publisher('~rects', RectArray, queue_size=1)
self.pub_image = rospy.Publisher('~output/image', Image, queue_size=1)
Expand Down Expand Up @@ -255,7 +256,7 @@ def inference(self, img_msg, queries):
classification_msg.label_names = labels
classification_msg.label_proba = scores # cosine similarities
classification_msg.probabilities = scores # sum(probabilities) is 1
classification_msg.classifier = 'dino'
classification_msg.classifier = self.model_name
classification_msg.target_names = queries
self.pub_class.publish(classification_msg)

Expand Down

0 comments on commit cb948d9

Please sign in to comment.