Skip to content

Commit

Permalink
Allow multiple object name input
Browse files Browse the repository at this point in the history
  • Loading branch information
Kanazawanaoaki committed Oct 15, 2023
1 parent 605a82a commit 89c9357
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions jsk_perception/docker/ofa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def encode_text(self, text, length=None, append_bos=False, append_eos=False):
s = torch.cat([s, eos_item])
return s

def convert_objects_to_text(self, text):
if len(text) == 1:
object_text = text[0]
elif len(text) >= 2:
object_text = ', '.join(text[:-1]) + f' or {text[-1]}'
else:
object_text = ''
return object_text

def construct_sample(self, image, text):
if self.task_name == "caption" or self.task_name == "vqa_gen":
patch_image = self.patch_resize_transform(image).unsqueeze(0)
Expand Down Expand Up @@ -176,7 +185,8 @@ def construct_sample(self, image, text):
h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
patch_image = self.patch_resize_transform(image).unsqueeze(0)
patch_mask = torch.tensor([True])
src_text = self.encode_text(' which region does the text " {} " describe?'.format(text), append_bos=True,
object_text = self.convert_objects_to_text(text)
src_text = self.encode_text(' which region does the text " {} " describe?'.format(object_text), append_bos=True,
append_eos=True).unsqueeze(0)
src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
sample = {
Expand Down Expand Up @@ -225,10 +235,11 @@ def infer(self, img, text):
with torch.no_grad():
result, scores = eval_step(self.task, self.generator, self.models, sample)
results = {}
object_text = self.convert_objects_to_text(text)
for i in range(len(result)):
box = result[i]["box"]
logit = scores[i].item()
results[i] = {"box": box, "logit": logit, "phrase": text}
results[i] = {"box": box, "logit": logit, "phrase": object_text}

return results

Expand Down

0 comments on commit 89c9357

Please sign in to comment.