Activate multimodality, Refactor llmWrapper
LLM wrappers are now subclasses of AbstractLLMWrapper, so different LLMs can be used seamlessly.
kimjammer committed Jun 12, 2024
1 parent bd62716 commit c7afcad
Showing 12 changed files with 201 additions and 95 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -116,7 +116,7 @@ documentation [here](https://pytwitchapi.dev/en/stable/index.html#user-authentic
A virtual environment of some sort is recommended (Python 3.11 required); this project was developed with venv.

First, install the CUDA 11.8 version of pytorch 2.2.2.
`pip install torch==2.2.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118`
`pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118`

Install requirements.txt.

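As a quick, optional sanity check (not part of this commit), the CUDA 11.8 wheels can be verified from Python before continuing:

```python
# Optional check: confirm the CUDA 11.8 builds of torch/torchvision/torchaudio installed correctly.
import torch
import torchaudio
import torchvision

print(torch.__version__, torchvision.__version__, torchaudio.__version__)
print("Built for CUDA:", torch.version.cuda)          # expected: 11.8
print("CUDA available:", torch.cuda.is_available())   # expected: True on a CUDA-capable machine
```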
constants.py (11 changes: 9 additions & 2 deletions)
@@ -31,10 +31,17 @@

MULTIMODAL_ENDPOINT = ""

MULTIMODAL_MODEL = "deepseek-vl"
MULTIMODAL_MODEL = "openbmb/MiniCPM-Llama3-V-2_5-int4"

MULTIMODAL_CONTEXT_SIZE = 8192

# This is the multimodal strategy (when to use multimodal/text only llm) that the program will start with.
# Runtime changes will not be saved here.
# Valid values are: "always", "never"
MULTIMODAL_STRATEGY = "always"

# This is the monitor index that screenshots will be taken from. THIS IS NOT THE MONITOR NUMBER IN DISPLAY SETTINGS
PRIMARY_MONITOR = 2
PRIMARY_MONITOR = 0

# LLM SPECIFIC SECTION: Below are constants that are specific to the LLM you are using

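The new `MULTIMODAL_STRATEGY` constant accepts only `"always"` or `"never"`. `PRIMARY_MONITOR` indexes mss's monitor list, where index 0 is the combined virtual screen spanning all displays and index 1 is the first physical display, which is why it is not the monitor number shown in Display Settings. A small startup guard like the following sketch (purely illustrative, not part of this commit) would catch a mistyped strategy early:

```python
# Hypothetical startup guard (not in the repo): fail fast if the constant is mistyped.
from constants import MULTIMODAL_STRATEGY

VALID_STRATEGIES = ("always", "never")
if MULTIMODAL_STRATEGY not in VALID_STRATEGIES:
    raise ValueError(
        f"MULTIMODAL_STRATEGY must be one of {VALID_STRATEGIES}, got {MULTIMODAL_STRATEGY!r}"
    )
```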
llmWrapper.py → llmWrappers/abstractLLMWrapper.py (87 changes: 28 additions & 59 deletions)
@@ -1,21 +1,19 @@
import os
import copy
import requests
import sseclient
import json
import time
from dotenv import load_dotenv
from transformers import AutoTokenizer
from constants import *
from modules.injection import Injection


class LLMWrapper:
class AbstractLLMWrapper:

def __init__(self, signals, tts, modules=None):
def __init__(self, signals, tts, llmState, modules=None):
self.signals = signals
self.llmState = llmState
self.tts = tts
self.blacklist = []
self.API = self.API(self)
if modules is None:
self.modules = {}
@@ -24,20 +22,18 @@ def __init__(self, signals, tts, modules=None):

self.headers = {"Content-Type": "application/json"}

self.enabled = True
self.next_cancelled = False

# Read in blacklist from file
with open('blacklist.txt', 'r') as file:
self.blacklist = file.read().splitlines()

load_dotenv()
self.tokenizer = AutoTokenizer.from_pretrained(MODEL, token=os.getenv("HF_TOKEN"))

#Below constants must be set by child classes
self.SYSTEM_PROMPT = None
self.LLM_ENDPOINT = None
self.CONTEXT_SIZE = None
self.tokenizer = None

# Basic filter to check if a message contains a word in the blacklist
def is_filtered(self, text):
# Filter messages with words in blacklist
if any(bad_word.lower() in text.lower().split() for bad_word in self.blacklist):
if any(bad_word.lower() in text.lower().split() for bad_word in self.llmState.blacklist):
return True
else:
return False
@@ -64,7 +60,6 @@ def assemble_injections(self, injections=None):
prompt += injection.text
return prompt

# This function is only used in completions mode
def generate_prompt(self):
messages = copy.deepcopy(self.signals.history)

@@ -82,7 +77,7 @@ def generate_prompt(self):

generation_prompt = AI_NAME + ": "

base_injections = [Injection(SYSTEM_PROMPT, 10), Injection(chat_section, 100)]
base_injections = [Injection(self.SYSTEM_PROMPT, 10), Injection(chat_section, 100)]
full_prompt = self.assemble_injections(base_injections) + generation_prompt
wrapper = [{"role": "user", "content": full_prompt}]

@@ -92,7 +87,7 @@ def generate_prompt(self):
# print(prompt_tokens)

# Maximum 90% context size usage before prompting LLM
if prompt_tokens < 0.9 * CONTEXT_SIZE:
if prompt_tokens < 0.9 * self.CONTEXT_SIZE:
self.signals.sio_queue.put(("full_prompt", full_prompt))
# print(full_prompt)
return full_prompt
@@ -105,62 +100,36 @@ def generate_prompt(self):
messages.pop(0)
print("Prompt too long, removing earliest message")

def prepare_payload(self):
raise NotImplementedError("Must implement prepare_payload in child classes")

def prompt(self):
if not self.enabled:
if not self.llmState.enabled:
return

self.signals.AI_thinking = True
self.signals.new_message = False
self.signals.sio_queue.put(("reset_next_message", None))

data = {
"mode": "instruct",
"stream": True,
"max_tokens": 200,
"skip_special_tokens": False, # Necessary for Llama 3
"custom_token_bans": BANNED_TOKENS,
"stop": STOP_STRINGS,
"messages": [{
"role": "user",
"content": self.generate_prompt()
}]
}

# Currently unused
if "multimodal" in self.modules:
if self.modules["multimodal"].API.multimodal_now():
data["messages"][0] = {
"role": "user",
"content": [
{
"type": "text",
"data": data["messages"][0]["content"]
},
{
"type": "image_url",
# OpenAI uses "url", "image_url" is for lmdeploy
"image_url": f"data:image/jpeg;base64,{self.modules['multimodal'].API.screen_shot()}"
}
]
}

stream_response = requests.post(LLM_ENDPOINT + "/v1/chat/completions", headers=self.headers, json=data,
data = self.prepare_payload()

stream_response = requests.post(self.LLM_ENDPOINT + "/v1/chat/completions", headers=self.headers, json=data,
verify=False, stream=True)
response_stream = sseclient.SSEClient(stream_response)

AI_message = ''
for event in response_stream.events():
# Check to see if next message was canceled
if self.next_cancelled:
if self.llmState.next_cancelled:
continue

payload = json.loads(event.data)
chunk = payload['choices'][0]['delta']['content']
AI_message += chunk
self.signals.sio_queue.put(("next_chunk", chunk))

if self.next_cancelled:
self.next_cancelled = False
if self.llmState.next_cancelled:
self.llmState.next_cancelled = False
self.signals.sio_queue.put(("reset_next_message", None))
self.signals.AI_thinking = False
return
@@ -183,10 +152,10 @@ def __init__(self, outer):
self.outer = outer

def get_blacklist(self):
return self.outer.blacklist
return self.outer.llmState.blacklist

def set_blacklist(self, new_blacklist):
self.outer.blacklist = new_blacklist
self.outer.llmState.blacklist = new_blacklist
with open('blacklist.txt', 'w') as file:
for word in new_blacklist:
file.write(word + "\n")
@@ -195,15 +164,15 @@ def set_blacklist(self, new_blacklist):
self.outer.signals.sio_queue.put(('get_blacklist', new_blacklist))

def set_LLM_status(self, status):
self.outer.enabled = status
self.outer.llmState.enabled = status
if status:
self.outer.signals.AI_thinking = False
self.outer.signals.sio_queue.put(('LLM_status', status))

def get_LLM_status(self):
return self.outer.enabled
return self.outer.llmState.enabled

def cancel_next(self):
self.outer.next_cancelled = True
self.outer.llmState.next_cancelled = True
# For text-generation-webui: Immediately stop generation
requests.post(LLM_ENDPOINT + "/v1/internal/stop-generation", headers={"Content-Type": "application/json"})
requests.post(self.outer.LLM_ENDPOINT + "/v1/internal/stop-generation", headers={"Content-Type": "application/json"})
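With this refactor, `prompt()` becomes a template method: it assembles the prompt, calls `prepare_payload()`, POSTs the result to the endpoint's `/v1/chat/completions` route, and streams the response over SSE, while each child class supplies `SYSTEM_PROMPT`, `LLM_ENDPOINT`, `CONTEXT_SIZE`, a tokenizer, and the request payload. A minimal child class, sketched here for illustration (it closely mirrors the `TextLLMWrapper` added later in this commit), would look roughly like this:

```python
# Illustrative minimal subclass of AbstractLLMWrapper; the class name is made up for the example.
import os

from transformers import AutoTokenizer

from constants import *
from llmWrappers.abstractLLMWrapper import AbstractLLMWrapper


class MinimalLLMWrapper(AbstractLLMWrapper):
    def __init__(self, signals, tts, llmState, modules=None):
        super().__init__(signals, tts, llmState, modules)
        # Every child must set these before prompt() is called.
        self.SYSTEM_PROMPT = SYSTEM_PROMPT
        self.LLM_ENDPOINT = LLM_ENDPOINT
        self.CONTEXT_SIZE = CONTEXT_SIZE
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL, token=os.getenv("HF_TOKEN"))

    def prepare_payload(self):
        # prompt() sends this dict as JSON to {LLM_ENDPOINT}/v1/chat/completions.
        return {
            "mode": "instruct",
            "stream": True,
            "max_tokens": 200,
            "custom_token_bans": BANNED_TOKENS,
            "stop": STOP_STRINGS,
            "messages": [{"role": "user", "content": self.generate_prompt()}],
        }
```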
llmWrappers/imageLLMWrapper.py (59 changes: 59 additions & 0 deletions)
@@ -0,0 +1,59 @@
import os
import mss, cv2, base64
import numpy as np
from transformers import AutoTokenizer
from constants import *
from llmWrappers.abstractLLMWrapper import AbstractLLMWrapper


class ImageLLMWrapper(AbstractLLMWrapper):

def __init__(self, signals, tts, llmState, modules=None):
super().__init__(signals, tts, llmState, modules)
self.SYSTEM_PROMPT = SYSTEM_PROMPT
self.LLM_ENDPOINT = MULTIMODAL_ENDPOINT
self.CONTEXT_SIZE = MULTIMODAL_CONTEXT_SIZE
self.tokenizer = AutoTokenizer.from_pretrained(MULTIMODAL_MODEL, token=os.getenv("HF_TOKEN"), trust_remote_code=True)

self.MSS = None

def screen_shot(self):
if self.MSS is None:
self.MSS = mss.mss()

# Take a screenshot of the main screen
frame_bytes = self.MSS.grab(self.MSS.monitors[PRIMARY_MONITOR])

frame_array = np.array(frame_bytes)
# resize
frame_resized = cv2.resize(frame_array, (1920, 1080), interpolation=cv2.INTER_CUBIC)
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
result, frame_encoded = cv2.imencode('.jpg', frame_resized, encode_param)
# base64
frame_base64 = base64.b64encode(frame_encoded).decode("utf-8")
return frame_base64

def prepare_payload(self):
return {
"mode": "instruct",
"stream": True,
"max_tokens": 200,
"skip_special_tokens": False, # Necessary for Llama 3
"custom_token_bans": BANNED_TOKENS,
"stop": STOP_STRINGS,
"messages": [{
"role": "user",
"content": [
{
"type": "text",
"text": self.generate_prompt()
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{self.screen_shot()}"
}
}
]
}]
}
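`screen_shot()` grabs the configured monitor with mss, resizes the frame to 1920x1080, JPEG-encodes it at quality 95, and returns the bytes base64-encoded so `prepare_payload()` can embed them as an OpenAI-style data URI. A small debugging helper along these lines (hypothetical, not part of the commit) can confirm the payload decodes back to a valid image:

```python
# Hypothetical helper: decode the base64 JPEG returned by ImageLLMWrapper.screen_shot().
import base64

import cv2
import numpy as np


def inspect_screenshot(frame_base64: str) -> None:
    raw = base64.b64decode(frame_base64)
    image = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
    print(f"JPEG payload: {len(raw)} bytes, decoded shape: {image.shape}")  # e.g. (1080, 1920, 3)
```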
llmWrappers/llmState.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
class LLMState:
def __init__(self):
self.enabled = True
self.next_cancelled = False

# Read in blacklist from file
with open('blacklist.txt', 'r') as file:
self.blacklist = file.read().splitlines()
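`LLMState` collects the mutable flags and the blacklist that previously lived on `LLMWrapper` itself. Because main.py passes a single instance to every wrapper, a change made through one wrapper's API is immediately visible to the others. A minimal illustration (assuming `blacklist.txt` exists in the working directory, as the constructor requires):

```python
# Minimal illustration of the state shared between wrappers.
from llmWrappers.llmState import LLMState

shared_state = LLMState()           # reads blacklist.txt, starts enabled
print(shared_state.enabled)         # True
print(shared_state.next_cancelled)  # False

# This is the flag that AbstractLLMWrapper.API.set_LLM_status(False) flips:
shared_state.enabled = False
print(shared_state.enabled)         # False: every wrapper holding this instance now skips prompting
```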
llmWrappers/textLLMWrapper.py (29 changes: 29 additions & 0 deletions)
@@ -0,0 +1,29 @@
import os

from transformers import AutoTokenizer
from constants import *
from llmWrappers.abstractLLMWrapper import AbstractLLMWrapper


class TextLLMWrapper(AbstractLLMWrapper):

def __init__(self, signals, tts, llmState, modules=None):
super().__init__(signals, tts, llmState, modules)
self.SYSTEM_PROMPT = SYSTEM_PROMPT
self.LLM_ENDPOINT = LLM_ENDPOINT
self.CONTEXT_SIZE = CONTEXT_SIZE
self.tokenizer = AutoTokenizer.from_pretrained(MODEL, token=os.getenv("HF_TOKEN"))

def prepare_payload(self):
return {
"mode": "instruct",
"stream": True,
"max_tokens": 200,
"skip_special_tokens": False, # Necessary for Llama 3
"custom_token_bans": BANNED_TOKENS,
"stop": STOP_STRINGS,
"messages": [{
"role": "user",
"content": self.generate_prompt()
}]
}
main.py (25 changes: 16 additions & 9 deletions)
@@ -8,13 +8,15 @@
# Class Imports
from signals import Signals
from prompter import Prompter
from llmWrapper import LLMWrapper
from llmWrappers.llmState import LLMState
from llmWrappers.textLLMWrapper import TextLLMWrapper
from llmWrappers.imageLLMWrapper import ImageLLMWrapper
from stt import STT
from tts import TTS
from modules.twitchClient import TwitchClient
from modules.audioPlayer import AudioPlayer
from modules.vtubeStudio import VtubeStudio
# from modules.multimodal import MultiModal
from modules.multimodal import MultiModal
from modules.customPrompt import CustomPrompt
from modules.memory import Memory
from socketioServer import SocketIOServer
@@ -46,28 +48,33 @@ def signal_handler(sig, frame):
stt = STT(signals)
# Create TTS
tts = TTS(signals)
# Create LLMWrapper
llm_wrapper = LLMWrapper(signals, tts, modules)
# Create LLMWrappers
llmState = LLMState()
llms = {
"text": TextLLMWrapper(signals, tts, llmState, modules),
"image": ImageLLMWrapper(signals, tts, llmState, modules)
}
# Create Prompter
prompter = Prompter(signals, llm_wrapper)
prompter = Prompter(signals, llms, modules)

# Create Discord bot
# modules['discord'] = DiscordClient(signals, stt, enabled=False)
# Create Twitch bot
modules['twitch'] = TwitchClient(signals, enabled=True)
modules['twitch'] = TwitchClient(signals, enabled=False)
# Create audio player
modules['audio_player'] = AudioPlayer(signals, enabled=True)
# Create Vtube Studio plugin
modules['vtube_studio'] = VtubeStudio(signals, enabled=True)
# Create Multimodal module (Currently no suitable models have been found/created)
# modules['multimodal'] = MultiModal(signals, enabled=False)
# Create Multimodal module
modules['multimodal'] = MultiModal(signals, enabled=False)
# Create Custom Prompt module
modules['custom_prompt'] = CustomPrompt(signals, enabled=True)
# Create Memory module
modules['memory'] = Memory(signals, enabled=True)

# Create Socket.io server
sio = SocketIOServer(signals, stt, tts, llm_wrapper, prompter, modules=modules)
# The specific llmWrapper it gets doesn't matter since state is shared between all llmWrappers
sio = SocketIOServer(signals, stt, tts, llms["text"], prompter, modules=modules)

# Create threads (As daemons, so they exit when the main thread exits)
prompter_thread = threading.Thread(target=prompter.prompt_loop, daemon=True)
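main.py now builds both wrappers around one `LLMState` and hands the dict to `Prompter`; the `SocketIOServer` only needs one of them because its API reads and writes the shared state. Prompter's dispatch logic is not part of this diff, but assuming the MultiModal module still exposes the `API.multimodal_now()` check referenced in the pre-refactor code, the per-turn selection might look something like this sketch:

```python
# Hypothetical sketch of the per-turn selection Prompter could perform; not the actual Prompter code.
def choose_and_prompt(llms, modules):
    multimodal = modules.get("multimodal")
    if multimodal is not None and multimodal.API.multimodal_now():
        llms["image"].prompt()   # include a screenshot this turn
    else:
        llms["text"].prompt()    # text-only prompt
```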