-
-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Activate multimodality, Refactor llmWrapper
LLM wrappers are now subclasses of AbstractLLMWrapper, so different LLMs can be used seamlessly.
- Loading branch information
Showing
12 changed files
with
201 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import mss, cv2, base64 | ||
import numpy as np | ||
from transformers import AutoTokenizer | ||
from constants import * | ||
from llmWrappers.abstractLLMWrapper import AbstractLLMWrapper | ||
|
||
|
||
class ImageLLMWrapper(AbstractLLMWrapper):
    """Multimodal LLM wrapper that attaches a screenshot of the primary
    monitor to every request as a base64-encoded JPEG image."""

    def __init__(self, signals, tts, llmState, modules=None):
        """Configure the multimodal endpoint, context size, and tokenizer.

        Args mirror AbstractLLMWrapper and are forwarded unchanged.
        """
        super().__init__(signals, tts, llmState, modules)
        self.SYSTEM_PROMPT = SYSTEM_PROMPT
        self.LLM_ENDPOINT = MULTIMODAL_ENDPOINT
        self.CONTEXT_SIZE = MULTIMODAL_CONTEXT_SIZE
        # trust_remote_code is needed because some multimodal models ship
        # custom tokenizer code on the Hub.
        self.tokenizer = AutoTokenizer.from_pretrained(MULTIMODAL_MODEL, token=os.getenv("HF_TOKEN"), trust_remote_code=True)

        # Created lazily in screen_shot(); presumably deferred so the mss
        # instance is created in the thread that actually captures --
        # TODO confirm against caller threading.
        self.MSS = None

    def screen_shot(self):
        """Capture the primary monitor and return it as a base64 JPEG string.

        Returns:
            str: base64-encoded JPEG of the (1920x1080-resized) screenshot.

        Raises:
            RuntimeError: if OpenCV fails to encode the frame as JPEG.
        """
        if self.MSS is None:
            self.MSS = mss.mss()

        # Take a screenshot of the main screen
        frame_bytes = self.MSS.grab(self.MSS.monitors[PRIMARY_MONITOR])

        frame_array = np.array(frame_bytes)
        # Resize to a fixed 1080p frame so payload size stays predictable
        frame_resized = cv2.resize(frame_array, (1920, 1080), interpolation=cv2.INTER_CUBIC)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
        result, frame_encoded = cv2.imencode('.jpg', frame_resized, encode_param)
        if not result:
            # Previously ignored: a failed encode would silently base64 an
            # invalid buffer and send garbage to the LLM endpoint.
            raise RuntimeError("cv2.imencode failed to encode the screenshot as JPEG")
        # base64
        frame_base64 = base64.b64encode(frame_encoded).decode("utf-8")
        return frame_base64

    def prepare_payload(self):
        """Build the OpenAI-style multimodal chat payload (prompt text plus
        a data-URL image part containing the current screenshot)."""
        return {
            "mode": "instruct",
            "stream": True,
            "max_tokens": 200,
            "skip_special_tokens": False,  # Necessary for Llama 3
            "custom_token_bans": BANNED_TOKENS,
            "stop": STOP_STRINGS,
            "messages": [{
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": self.generate_prompt()
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{self.screen_shot()}"
                        }
                    }
                ]
            }]
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
class LLMState:
    """Shared mutable state for LLM wrappers: runtime flags plus the
    word blacklist loaded from disk."""

    def __init__(self, blacklist_path='blacklist.txt'):
        """Initialize flags and load the blacklist.

        Args:
            blacklist_path: Path to a newline-separated blacklist file.
                Defaults to 'blacklist.txt' for backward compatibility.

        Raises:
            OSError: if the blacklist file cannot be opened.
        """
        # Whether the LLM is currently allowed to generate responses.
        self.enabled = True
        # When True, the next pending generation is discarded.
        self.next_cancelled = False

        # Read in blacklist from file; explicit UTF-8 avoids
        # platform-dependent default decoding.
        with open(blacklist_path, 'r', encoding='utf-8') as file:
            self.blacklist = file.read().splitlines()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
|
||
from transformers import AutoTokenizer | ||
from constants import * | ||
from llmWrappers.abstractLLMWrapper import AbstractLLMWrapper | ||
|
||
|
||
class TextLLMWrapper(AbstractLLMWrapper):
    """Text-only LLM wrapper: points the abstract base at the plain-text
    endpoint/model and builds a text-only instruct payload."""

    def __init__(self, signals, tts, llmState, modules=None):
        """Configure endpoint, context size, and tokenizer for the text model.

        Args mirror AbstractLLMWrapper and are forwarded unchanged.
        """
        super().__init__(signals, tts, llmState, modules)
        self.SYSTEM_PROMPT = SYSTEM_PROMPT
        self.LLM_ENDPOINT = LLM_ENDPOINT
        self.CONTEXT_SIZE = CONTEXT_SIZE
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL, token=os.getenv("HF_TOKEN"))

    def prepare_payload(self):
        """Assemble the streaming instruct-mode request body."""
        payload = {
            "mode": "instruct",
            "stream": True,
            "max_tokens": 200,
            "skip_special_tokens": False,  # Necessary for Llama 3
            "custom_token_bans": BANNED_TOKENS,
            "stop": STOP_STRINGS,
        }
        payload["messages"] = [{"role": "user", "content": self.generate_prompt()}]
        return payload
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.