Skip to content

Commit

Permalink
Fixes bug with image query (#40)
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Ramayanam <[email protected]>
  • Loading branch information
HarshaRamayanam authored Jan 10, 2025
1 parent fda37f4 commit 675735e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
13 changes: 6 additions & 7 deletions MultimodalQnA/ui/gradio/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class Conversation:
split_video: str = None
image: str = None
audio_query_file: str = None
image_query_file: str = None
pdf: str = None

def _template_caption(self):
Expand All @@ -49,9 +48,9 @@ def get_prompt(self):
# Need to do RAG. If the query is text, prompt is the query only
if self.audio_query_file:
ret = [{"role": "user", "content": [{"type": "audio", "audio": self.get_b64_audio_query()}]}]
elif self.image_query_file:
b64_image = get_b64_frame_from_timestamp(self.image_query_file, 0)
ret = [{"role": "user", "content": [{"type": "text", "text": self.messages[0][1]},{"type": "image_url", "image_url": {"url": b64_image}}]}]
elif len(messages) in self.image_query_files:
b64_image = get_b64_frame_from_timestamp(self.image_query_files[len(messages)], 0)
ret = [{"role": "user", "content": [{"type": "text", "text": messages[0][1]},{"type": "image_url", "image_url": {"url": b64_image}}]}]
else:
ret = messages[0][1]
else:
Expand Down Expand Up @@ -80,8 +79,8 @@ def get_prompt(self):
content[0]["text"] = content[0]["text"] + " " + self._template_caption()
content.append({"type": "image_url", "image_url": {"url": base64_frame}})
# There might be a query image
if self.image_query_file:
content.append({"type": "image_url", "image_url": {"url": self.image_query_file}})
if i+2 in self.image_query_files:
content.append({"type": "image_url", "image_url": {"url": get_b64_frame_from_timestamp(self.image_query_files[i+2], 0)}})
dic["content"] = content
conv_dict.append(dic)
else:
Expand Down Expand Up @@ -132,7 +131,7 @@ def to_gradio_chatbot(self):
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>", "").strip()
ret.append([msg, None])
elif self.image_query_file:
elif i in self.image_query_files:
import base64
from io import BytesIO

Expand Down
3 changes: 1 addition & 2 deletions MultimodalQnA/ui/gradio/multimodalqna_ui_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def add_text(state, textbox, audio, request: gr.Request):
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), None, None) + (disable_btn,) * 1
# If it is a image query
elif textbox['files']:
image_file = textbox['files'][0]
state.image_query_file = image_file
state.image_query_files[len(state.messages)] = image_file
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
Expand Down Expand Up @@ -109,7 +109,6 @@ def http_bot(state, request: gr.Request):
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
new_state.audio_query_file = state.audio_query_file
new_state.image_query_file = state.image_query_file
new_state.image_query_files = state.image_query_files
state = new_state

Expand Down

0 comments on commit 675735e

Please sign in to comment.