From 4dd5b4b26f61336c11001c1e150be5e5e272d076 Mon Sep 17 00:00:00 2001
From: zhuangqh
Date: Thu, 12 Dec 2024 16:29:03 +1100
Subject: [PATCH] feat: add demoui for openai api

---
 charts/DemoUI/inference/README.md             |  4 +-
 .../inference/templates/deployment.yaml       | 22 ++++++--
 charts/DemoUI/inference/values.yaml           | 10 ++--
 demo/inferenceUI/README.md                    |  4 +-
 demo/inferenceUI/chainlit_openai.py           | 60 +++++++++++++++++++
 .../{chainlit.py => chainlit_transformers.py} |  3 +-
 6 files changed, 90 insertions(+), 13 deletions(-)
 create mode 100644 demo/inferenceUI/chainlit_openai.py
 rename demo/inferenceUI/{chainlit.py => chainlit_transformers.py} (93%)

diff --git a/charts/DemoUI/inference/README.md b/charts/DemoUI/inference/README.md
index ba314e6a9..4ff716f1f 100644
--- a/charts/DemoUI/inference/README.md
+++ b/charts/DemoUI/inference/README.md
@@ -5,12 +5,12 @@ Before deploying the Demo front-end, you must set the `workspaceServiceURL` envi
 To set this value, modify the `values.override.yaml` file or use the `--set` flag during Helm install/upgrade:
 
 ```bash
-helm install inference-frontend ./charts/DemoUI/inference/values.yaml --set env.workspaceServiceURL="http://<CLUSTER_IP>:80/chat"
+helm install inference-frontend ./charts/DemoUI/inference --set env.workspaceServiceURL="http://<CLUSTER_IP>:80"
 ```
 
 Or through a custom `values` file (`values.override.yaml`):
 ```bash
-helm install inference-frontend ./charts/DemoUI/inference/values.yaml -f values.override.yaml
+helm install inference-frontend ./charts/DemoUI/inference -f values.override.yaml
 ```
 
 ## Values
diff --git a/charts/DemoUI/inference/templates/deployment.yaml b/charts/DemoUI/inference/templates/deployment.yaml
index 71c50d383..483c50dbc 100644
--- a/charts/DemoUI/inference/templates/deployment.yaml
+++ b/charts/DemoUI/inference/templates/deployment.yaml
@@ -37,13 +37,27 @@ spec:
           args:
             - -c
             - |
-              mkdir -p /app/frontend && \
-              pip install chainlit requests && \
-              wget -O /app/frontend/inference.py https://raw.githubusercontent.com/kaito-project/kaito/main/demo/inferenceUI/chainlit.py && \
-              chainlit run frontend/inference.py -w
+              set -e
+              mkdir -p /app/frontend
+              pip install chainlit pydantic==2.10.1 requests openai --quiet
+              case "$RUNTIME" in
+                vllm)
+                  wget -O /app/frontend/inference.py https://raw.githubusercontent.com/kaito-project/kaito/refs/heads/main/demo/inferenceUI/chainlit_openai.py
+                  ;;
+                transformers)
+                  wget -O /app/frontend/inference.py https://raw.githubusercontent.com/kaito-project/kaito/refs/heads/main/demo/inferenceUI/chainlit_transformers.py
+                  ;;
+                *)
+                  echo "Error: Unsupported RUNTIME value" >&2
+                  exit 1
+                  ;;
+              esac
+              chainlit run --host 0.0.0.0 /app/frontend/inference.py -w
           env:
             - name: WORKSPACE_SERVICE_URL
               value: "{{ .Values.env.workspaceServiceURL }}"
+            - name: RUNTIME
+              value: "{{ .Values.env.runtime }}"
           workingDir: /app
           ports:
             - name: http
diff --git a/charts/DemoUI/inference/values.yaml b/charts/DemoUI/inference/values.yaml
index 3f925fbfe..77b4dca06 100644
--- a/charts/DemoUI/inference/values.yaml
+++ b/charts/DemoUI/inference/values.yaml
@@ -4,7 +4,7 @@ replicaCount: 1
 image:
   repository: python
   pullPolicy: IfNotPresent
-  tag: "3.8"
+  tag: "3.12"
 imagePullSecrets: []
 podAnnotations: {}
 serviceAccount:
@@ -18,9 +18,9 @@ service:
   # Specify the URL for the Workspace Service inference endpoint. Use the DNS name within the cluster for reliability.
   #
   # Examples:
-  #   Cluster IP: "http://<CLUSTER_IP>:80/chat"
-  #   DNS name: "http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80/chat"
-  #     e.g., "http://workspace-falcon-7b.default.svc.cluster.local:80/chat"
+  #   Cluster IP: "http://<CLUSTER_IP>:80"
+  #   DNS name: "http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80"
+  #     e.g., "http://workspace-falcon-7b.default.svc.cluster.local:80"
   #
   # workspaceServiceURL: ""
 resources:
@@ -44,6 +44,8 @@ readinessProbe:
   periodSeconds: 10
   successThreshold: 1
   timeoutSeconds: 1
+env:
+  runtime: "vllm" # "vllm" or "transformers"
 nodeSelector: {}
 tolerations: []
 affinity: {}
diff --git a/demo/inferenceUI/README.md b/demo/inferenceUI/README.md
index f803dd72e..74d693def 100644
--- a/demo/inferenceUI/README.md
+++ b/demo/inferenceUI/README.md
@@ -20,12 +20,12 @@ Workspace Service endpoint.
 
 - Using the --set flag:
     ```
-    helm install inference-frontend ./charts/DemoUI/inference --set env.workspaceServiceURL="http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80/chat"
+    helm install inference-frontend ./charts/DemoUI/inference --set env.workspaceServiceURL="http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80"
     ```
 - Using a custom `values.override.yaml` file:
     ```
     env:
-      workspaceServiceURL: "http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80/chat"
+      workspaceServiceURL: "http://<SERVICE_NAME>.<NAMESPACE>.svc.cluster.local:80"
     ```
     Then deploy with custom values file:
     ```
diff --git a/demo/inferenceUI/chainlit_openai.py b/demo/inferenceUI/chainlit_openai.py
new file mode 100644
index 000000000..b17c7cd71
--- /dev/null
+++ b/demo/inferenceUI/chainlit_openai.py
@@ -0,0 +1,60 @@
+import os
+from urllib.parse import urljoin
+
+from openai import AsyncOpenAI
+import chainlit as cl
+
+URL = os.environ.get('WORKSPACE_SERVICE_URL')
+if not URL:
+    # Fail fast: urljoin(None, "v1") would silently produce the relative URL "v1".
+    raise RuntimeError("WORKSPACE_SERVICE_URL must be set to the workspace service URL")
+
+# Client for the OpenAI-compatible API under the service's /v1 path; the API key is a placeholder.
+client = AsyncOpenAI(base_url=urljoin(URL, "v1"), api_key="YOUR_OPENAI_API_KEY")
+cl.instrument_openai()
+
+# Sampling parameters sent with every chat completion request.
+settings = {
+    "temperature": 0.7,
+    "max_tokens": 500,
+    "top_p": 1,
+    "frequency_penalty": 0,
+    "presence_penalty": 0,
+}
+
+
+@cl.on_chat_start
+async def start_chat():
+    """Select the first model listed by the endpoint for this chat session."""
+    models = await client.models.list()
+    print(f"Available models: {models}")
+    if not models.data:
+        raise ValueError("No models found")
+
+    model = models.data[0].id
+    # Keep the choice in the per-session store instead of a module global.
+    cl.user_session.set("model", model)
+    print(f"Using model: {model}")
+
+
+@cl.on_message
+async def main(message: cl.Message):
+    """Stream a chat completion for the user's message back to the UI."""
+    messages = [
+        {"content": "You are a helpful assistant.", "role": "system"},
+        {"content": message.content, "role": "user"},
+    ]
+    msg = cl.Message(content="")
+
+    stream = await client.chat.completions.create(
+        messages=messages,
+        model=cl.user_session.get("model"),
+        stream=True,
+        **settings,
+    )
+
+    async for part in stream:
+        # Skip chunks without choices or without text in the delta.
+        if part.choices and (token := part.choices[0].delta.content):
+            await msg.stream_token(token)
+    await msg.update()
diff --git a/demo/inferenceUI/chainlit.py b/demo/inferenceUI/chainlit_transformers.py
similarity index 93%
rename from demo/inferenceUI/chainlit.py
rename to demo/inferenceUI/chainlit_transformers.py
index 594369654..a906d44fb 100644
--- a/demo/inferenceUI/chainlit.py
+++ b/demo/inferenceUI/chainlit_transformers.py
@@ -1,4 +1,5 @@
 import os
+from urllib.parse import urljoin
 
 import chainlit as cl
 import requests
@@ -25,7 +26,7 @@ def inference(prompt):
         }
     }
 
-    response = requests.post(URL, json=data)
+    response = requests.post(urljoin(URL, "chat"), json=data)
 
     if response.status_code == 200:
         response_data = response.json()