From 35a32dc06fc0c97bd52a695dd2eefaf11e9862bb Mon Sep 17 00:00:00 2001
From: eric2788
Date: Wed, 23 Oct 2024 18:12:16 +0800
Subject: [PATCH] added web-llm to llm provider

---
 package.json                        |  2 +
 pnpm-lock.yaml                      | 24 ++++++++
 src/contents/index/mounter.tsx      |  2 +-
 src/hooks/life-cycle.ts             | 30 ++++++++++
 src/llms/cloudflare-ai.ts           |  2 +-
 src/llms/gemini-nano.ts             |  8 +--
 src/llms/index.ts                   |  8 ++-
 src/llms/models.ts                  | 33 +++++++++++
 src/llms/remote-worker.ts           |  2 +-
 src/llms/web-llm.ts                 | 87 +++++++++++++++++++++++++++++
 src/options/components/Selector.tsx |  3 +-
 src/options/fragments/llm.tsx       | 82 +++++++++++++++++----------
 tests/units/llm.spec.ts             | 64 +++++++++++++++++++++
 tsconfig.json                       |  3 +
 14 files changed, 311 insertions(+), 39 deletions(-)
 create mode 100644 src/llms/models.ts
 create mode 100644 src/llms/web-llm.ts

diff --git a/package.json b/package.json
index f766ba57..90d49282 100644
--- a/package.json
+++ b/package.json
@@ -27,6 +27,7 @@
     "@ffmpeg/ffmpeg": "^0.12.10",
     "@ffmpeg/util": "^0.12.1",
     "@material-tailwind/react": "^2.1.9",
+    "@mlc-ai/web-llm": "^0.2.73",
     "@plasmohq/messaging": "^0.6.2",
     "@plasmohq/storage": "^1.9.3",
     "@react-hooks-library/core": "^0.5.2",
@@ -64,6 +65,7 @@
     "@types/react": "18.2.37",
     "@types/react-dom": "18.2.15",
     "@types/semver": "^7.5.8",
+    "@webgpu/types": "^0.1.49",
     "dotenv": "^16.4.5",
     "esbuild": "^0.20.2",
     "gify-parse": "^1.0.7",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index a98c6037..caa0b650 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -29,6 +29,9 @@ importers:
       '@material-tailwind/react':
         specifier: ^2.1.9
         version: 2.1.9(react-dom@18.2.0(react@18.2.0))(react@18.2.0)
+      '@mlc-ai/web-llm':
+        specifier: ^0.2.73
+        version: 0.2.73
       '@plasmohq/messaging':
         specifier: ^0.6.2
         version: 0.6.2(react@18.2.0)
@@ -135,6 +138,9 @@ importers:
       '@types/semver':
         specifier: ^7.5.8
         version: 7.5.8
+      '@webgpu/types':
+        specifier: ^0.1.49
+        version: 0.1.49
       dotenv:
         specifier: ^16.4.5
         version: 16.4.5
@@ -856,6 +862,9 @@ packages:
     resolution: {integrity: sha512-iA7+tyVqfrATAIsIRWQG+a7ZLLD0VaOCKV2Wd/v4mqIU3J9c4jx9p7S0nw1XH3gJCKNBOOwACOPYYSUu9pgT+w==}
     engines: {node: '>=12.0.0'}
 
+  '@mlc-ai/web-llm@0.2.73':
+    resolution: {integrity: sha512-v9eS05aptZYNBhRZ7KEq5dr1kMESHvkdwru9tNY1oS1GDDEBjUt4Y/u06ejvP3qF4VatCakEezzJT3SkuMhDnQ==}
+
   '@motionone/animation@10.17.0':
     resolution: {integrity: sha512-ANfIN9+iq1kGgsZxs+Nz96uiNcPLGTXwfNo2Xz/fcJXniPYpaz/Uyrfa+7I5BPLxCP82sh7quVDudf1GABqHbg==}
 
@@ -1946,6 +1955,9 @@
   '@vue/shared@3.3.4':
     resolution: {integrity: sha512-7OjdcV8vQ74eiz1TZLzZP4JwqM5fA94K6yntPS5Z25r9HDuGNzaGdgvwKYq6S+MxwF0TFRwe50fIR/MYnakdkQ==}
 
+  '@webgpu/types@0.1.49':
+    resolution: {integrity: sha512-NMmS8/DofhH/IFeW+876XrHVWel+J/vdcFCHLDqeJgkH9x0DeiwjVd8LcBdaxdG/T7Rf8VUAYsA8X1efMzLjRQ==}
+
   abortcontroller-polyfill@1.7.5:
     resolution: {integrity: sha512-JMJ5soJWP18htbbxJjG7bG6yuI6pRhgJ0scHHTfkUjf6wjP912xZWvM+A4sJK3gqd9E8fcPbDnOefbA9Th/FIQ==}
 
@@ -3043,6 +3055,10 @@
     resolution: {integrity: sha512-8XPvpAA8uyhfteu8pIvQxpJZ7SYYdpUivZpGy6sFsBuKRY/7rQGavedeB8aK+Zkyq6upMFVL/9AW6vOYzfRyLg==}
     engines: {node: '>=10'}
 
+  loglevel@1.9.2:
+    resolution: {integrity: sha512-HgMmCqIJSAKqo68l0rS2AanEWfkxaZ5wNiEFb5ggm08lDs9Xl2KxBlX3PTcaD2chBM1gXAYf491/M2Rv8Jwayg==}
+    engines: {node: '>= 0.6.0'}
+
   loose-envify@1.4.0:
     resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==}
     hasBin: true
@@ -4673,6 +4689,10 @@
       '@lezer/lr': 1.4.0
       json5: 2.2.3
 
+  '@mlc-ai/web-llm@0.2.73':
+    dependencies:
+      loglevel: 1.9.2
+
   '@motionone/animation@10.17.0':
     dependencies:
       '@motionone/easing': 10.17.0
 
@@ -6145,6 +6165,8 @@ snapshots:
   '@vue/shared@3.3.4': {}
 
+  '@webgpu/types@0.1.49': {}
+
   abortcontroller-polyfill@1.7.5: {}
 
   acorn@8.12.1: {}
 
@@ -7203,6 +7225,8 @@ snapshots:
     chalk: 4.1.2
     is-unicode-supported: 0.1.0
 
+  loglevel@1.9.2: {}
+
   loose-envify@1.4.0:
     dependencies:
       js-tokens: 4.0.0
diff --git a/src/contents/index/mounter.tsx b/src/contents/index/mounter.tsx
index 5f6bbf59..fd097533 100644
--- a/src/contents/index/mounter.tsx
+++ b/src/contents/index/mounter.tsx
@@ -173,7 +173,7 @@ function createApp(roomId: string, plasmo: PlasmoSpec, info: StreamInfo): App {
                 onClick: () => this.start()
             },
             position: 'top-left',
-            duration: 6000000,
+            duration: Infinity,
             dismissible: false
         })
         return
diff --git a/src/hooks/life-cycle.ts b/src/hooks/life-cycle.ts
index 79d8c242..13fabace 100644
--- a/src/hooks/life-cycle.ts
+++ b/src/hooks/life-cycle.ts
@@ -74,4 +74,34 @@ export function useTimeoutElement(before: JSX.Element, after: JSX.Element, timeo
     }, [after, before, timeout]);
 
     return element;
-}
\ No newline at end of file
+}
+
+/**
+ * A React hook that ensures the provided disposable resource is properly disposed of
+ * when the component is unmounted.
+ *
+ * @param disposable - The resource to be disposed of, which can be either a `Disposable`
+ * or an `AsyncDisposable`. The resource must implement either the `Symbol.dispose` or
+ * `Symbol.asyncDispose` method.
+ *
+ * @example
+ * ```typescript
+ * const MyComponent = () => {
+ *   const disposableResource = useMemo(() => createDisposableResource(), []);
+ *   useDisposable(disposableResource);
+ *
+ *   return <div>My Component</div>;
+ * };
+ * ```
+ */
+export function useDisposable(disposable: Disposable | AsyncDisposable) {
+    useEffect(() => {
+        return () => {
+            if (!!disposable[Symbol.dispose]) {
+                disposable[Symbol.dispose]();
+            } else if (!!disposable[Symbol.asyncDispose]) {
+                disposable[Symbol.asyncDispose]();
+            }
+        }
+    }, [])
+}
\ No newline at end of file
diff --git a/src/llms/cloudflare-ai.ts b/src/llms/cloudflare-ai.ts
index 260c855c..44acd030 100644
--- a/src/llms/cloudflare-ai.ts
+++ b/src/llms/cloudflare-ai.ts
@@ -41,7 +41,7 @@ export default class CloudFlareAI implements LLMProviders {
         console.warn('Cloudflare AI session is not supported')
         return {
             ...this,
-            [Symbol.dispose]: () => { }
+            [Symbol.asyncDispose]: async () => { }
         }
     }
 
diff --git a/src/llms/gemini-nano.ts b/src/llms/gemini-nano.ts
index f2da1f4e..d9dd23a6 100644
--- a/src/llms/gemini-nano.ts
+++ b/src/llms/gemini-nano.ts
@@ -19,7 +19,7 @@ export default class GeminiNano implements LLMProviders {
             return session.prompt(chat)
         } finally {
             console.debug('[gemini nano] done')
-            session[Symbol.dispose]()
+            await session[Symbol.asyncDispose]()
         }
     }
 
@@ -33,7 +33,7 @@ export default class GeminiNano implements LLMProviders {
             }
         } finally {
             console.debug('[gemini nano] done')
-            session[Symbol.dispose]()
+            await session[Symbol.asyncDispose]()
         }
     }
 
@@ -82,7 +82,7 @@ class GeminiAssistant implements Session<GeminiNano> {
         }
     }
 
-    [Symbol.dispose](): void {
+    async [Symbol.asyncDispose]() {
         this.assistant.destroy()
     }
 }
@@ -105,7 +105,7 @@ class GeminiSummarizer implements Session<GeminiNano> {
         }
     }
 
-    [Symbol.dispose](): void {
+    async [Symbol.asyncDispose]() {
         this.summarizer.destroy()
     }
 
diff --git a/src/llms/index.ts b/src/llms/index.ts
index 15031b7a..c57b1212 100644
--- a/src/llms/index.ts
+++ b/src/llms/index.ts
@@ -3,21 +3,23 @@ import type { SettingSchema as LLMSchema } from '~options/fragments/llm'
 import cloudflare from './cloudflare-ai'
 import nano from './gemini-nano'
 import worker from './remote-worker'
+import webllm from './web-llm'
 
 export interface LLMProviders {
     cumulative: boolean
-    validate(): Promise<void>
+    validate(progress?: (p: number, t: string) => void): Promise<void>
     prompt(chat: string): Promise<string>
     promptStream(chat: string): AsyncGenerator<string>
     asSession(): Promise<Session<LLMProviders>>
 }
 
-export type Session<T> = Disposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>
+export type Session<T> = AsyncDisposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>
 
 const llms = {
     cloudflare,
     nano,
-    worker
+    worker,
+    webllm
 }
 
 export type LLMs = typeof llms
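The interface change above is the backbone of this patch: `validate()` gains an optional progress callback (a 0–1 fraction plus a status line) so providers that download model weights can report progress, and `Session` switches from `Disposable` to `AsyncDisposable` because engine teardown is asynchronous. A minimal consumer sketch under those two changes — `runPrompt` is an illustrative name, not part of the patch:

```typescript
// Illustrative end-to-end consumer of the patched interface; `runPrompt`
// is not part of the patch.
import createLLMProvider from '~llms'
import type { SettingSchema } from '~options/fragments/llm'

async function runPrompt(settings: SettingSchema, chat: string): Promise<string> {
    const provider = createLLMProvider(settings)
    // validate() may download model weights; p is a fraction in [0, 1].
    await provider.validate((p, t) => console.debug(`${t} (${Math.round(p * 100)}%)`))
    const session = await provider.asSession()
    try {
        return await session.prompt(chat)
    } finally {
        // Session is AsyncDisposable after this patch: disposal must be awaited.
        await session[Symbol.asyncDispose]()
    }
}
```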
diff --git a/src/llms/models.ts b/src/llms/models.ts
new file mode 100644
index 00000000..23f2b9f1
--- /dev/null
+++ b/src/llms/models.ts
@@ -0,0 +1,33 @@
+import type { LLMTypes } from "~llms"
+
+export type ModelList = {
+    providers: LLMTypes[]
+    models: string[]
+}
+
+const models: ModelList[] = [
+    {
+        providers: ['worker', 'cloudflare'],
+        models: [
+            '@cf/qwen/qwen1.5-14b-chat-awq',
+            '@cf/qwen/qwen1.5-7b-chat-awq',
+            '@cf/qwen/qwen1.5-1.8b-chat',
+            '@hf/google/gemma-7b-it',
+            '@hf/nousresearch/hermes-2-pro-mistral-7b'
+        ]
+    },
+    {
+        providers: [ 'webllm' ],
+        models: [
+            'Qwen2-7B-Instruct-q4f32_1-MLC',
+            'Qwen2.5-14B-Instruct-q4f16_1-MLC',
+            'gemma-2-9b-it-q4f16_1-MLC',
+            'Qwen2.5-3B-Instruct-q0f16-MLC',
+            'Phi-3-mini-128k-instruct-q0f16-MLC',
+            'Phi-3.5-mini-instruct-q4f16_1-MLC-1k'
+        ]
+    }
+]
+
+
+export default models
\ No newline at end of file
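The new `models.ts` centralizes which chat models each provider accepts; `nano` deliberately has no entry. The options UI (patched later in this diff) derives its model dropdown by exactly this filter — a small sketch, with `modelsFor` as an illustrative helper name:

```typescript
// Illustrative helper mirroring what the options page does with this list;
// `modelsFor` is not part of the patch.
import models from '~llms/models'
import type { LLMTypes } from '~llms'

function modelsFor(provider: LLMTypes): string[] {
    return models
        .filter(({ providers }) => providers.includes(provider))
        .flatMap(({ models }) => models)
}

console.log(modelsFor('webllm'))  // the six local MLC builds listed above
console.log(modelsFor('nano'))    // [] — no model choice for Gemini Nano
```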
diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts
index 2866014e..6c859cfe 100644
--- a/src/llms/remote-worker.ts
+++ b/src/llms/remote-worker.ts
@@ -54,7 +54,7 @@ export default class RemoteWorker implements LLMProviders {
         console.warn('Remote worker session is not supported')
         return {
             ...this,
-            [Symbol.dispose]: () => { }
+            [Symbol.asyncDispose]: async () => { }
         }
     }
 
diff --git a/src/llms/web-llm.ts b/src/llms/web-llm.ts
new file mode 100644
index 00000000..6d643859
--- /dev/null
+++ b/src/llms/web-llm.ts
@@ -0,0 +1,87 @@
+import type { MLCEngine } from "@mlc-ai/web-llm";
+import type { LLMProviders, Session } from "~llms";
+import type { SettingSchema } from "~options/fragments/llm";
+
+export default class WebLLM implements LLMProviders {
+
+    private static readonly DEFAULT_MODEL: string = 'Qwen2-7B-Instruct-q4f32_1-MLC'
+
+    private readonly model: string
+
+    constructor(settings: SettingSchema) {
+        this.model = settings.model || WebLLM.DEFAULT_MODEL
+    }
+
+    cumulative: boolean = true
+
+    async validate(progresser?: (p: number, t: string) => void): Promise<void> {
+        await this.initializeEngine(progresser)
+    }
+
+    async prompt(chat: string): Promise<string> {
+        const session = await this.asSession()
+        try {
+            console.debug('[web-llm] prompting: ', chat)
+            return await session.prompt(chat)
+        } finally {
+            console.debug('[web-llm] done')
+            await session[Symbol.asyncDispose]()
+        }
+    }
+
+    async *promptStream(chat: string): AsyncGenerator<string> {
+        const session = await this.asSession()
+        try {
+            console.debug('[web-llm] prompting stream: ', chat)
+            const res = session.promptStream(chat)
+            for await (const chunk of res) {
+                yield chunk
+            }
+        } finally {
+            console.debug('[web-llm] done')
+            await session[Symbol.asyncDispose]()
+        }
+    }
+
+    async asSession(): Promise<Session<LLMProviders>> {
+        const engine = await this.initializeEngine()
+        return {
+            async prompt(chat: string) {
+                await engine.interruptGenerate()
+                const c = await engine.completions.create({
+                    prompt: chat,
+                    max_tokens: 512,
+                    temperature: 0.2,
+                })
+                return c.choices[0]?.text ?? engine.getMessage()
+            },
+            async *promptStream(chat: string): AsyncGenerator<string> {
+                await engine.interruptGenerate()
+                const chunks = await engine.completions.create({
+                    prompt: chat,
+                    max_tokens: 512,
+                    temperature: 0.2,
+                    stream: true
+                })
+                for await (const chunk of chunks) {
+                    yield chunk.choices[0]?.text || "";
+                    if (chunk.usage) {
+                        console.debug('Usage:', chunk.usage)
+                    }
+                }
+            },
+            [Symbol.asyncDispose]: () => engine.unload()
+        }
+    }
+
+    private async initializeEngine(progresser?: (p: number, t: string) => void): Promise<MLCEngine> {
+        const { CreateMLCEngine } = await import('@mlc-ai/web-llm')
+        return CreateMLCEngine(this.model, {
+            initProgressCallback: (progress) => {
+                progresser?.(progress.progress, "正在下载AI模型到本地")
+                console.log('初始化进度:', progress)
+            }
+        })
+    }
+
+}
\ No newline at end of file
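Note that `asSession()` resolves only once `CreateMLCEngine` has the weights in memory, so whoever holds a session also holds GPU memory until `[Symbol.asyncDispose]()` runs. The `useDisposable` hook added earlier in this patch exists for exactly that pairing. A sketch — `useSummary` is an illustrative name, and `~hooks/life-cycle` assumes the project's usual path alias:

```typescript
// Illustrative hook: stream a completion into state and let useDisposable
// unload the engine when the component unmounts.
import { useEffect, useState } from 'react'
import { useDisposable } from '~hooks/life-cycle'
import type { LLMProviders, Session } from '~llms'

function useSummary(session: Session<LLMProviders>, chat: string): string {
    const [text, setText] = useState('')
    useDisposable(session) // calls session[Symbol.asyncDispose]() on unmount
    useEffect(() => {
        let cancelled = false
        ;(async () => {
            for await (const chunk of session.promptStream(chat)) {
                if (cancelled) break
                setText(prev => prev + chunk)
            }
        })()
        return () => { cancelled = true }
    }, [session, chat])
    return text
}
```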
diff --git a/src/options/components/Selector.tsx b/src/options/components/Selector.tsx
index cc94049a..f7f5b204 100644
--- a/src/options/components/Selector.tsx
+++ b/src/options/components/Selector.tsx
@@ -15,6 +15,7 @@ export type SelectorProps<T> = {
     disabled?: boolean
     options: SelectorOption<T>[]
     className?: string
+    emptyValue?: string
 }
 
@@ -37,7 +38,7 @@ function Selector<T>(props: SelectorProps<T>): JSX.Element {
         <div onClick={() => !props.disabled && setOpen(!isOpen)}>
             <span>
-                {props.options.find((option) => option.value === props.value)?.label ?? String(props.value)}
+                {props.options.find((option) => option.value === props.value)?.label || String(props.value || (props.emptyValue || '请选择'))}
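The practical effect of `emptyValue` together with the `??` → `||` switch: when `value` is `undefined` (as `state.model` now is right after switching providers), the closed selector shows the supplied placeholder instead of the string "undefined". An illustrative wrapper, mirroring how the settings fragment below uses it:

```typescript
// Illustrative only: a minimal wrapper showing the emptyValue fallback.
// `Selector` is the component patched above; '默认' renders whenever
// `model` is undefined or empty, instead of the string "undefined".
import { useState } from 'react'
import Selector from '~options/components/Selector'

function ModelPicker({ models }: { models: string[] }): JSX.Element {
    const [model, setModel] = useState<string | undefined>(undefined)
    return (
        <Selector
            label="模型提供"
            value={model}
            onChange={setModel}
            options={models.map(m => ({ label: m, value: m }))}
            emptyValue="默认"
        />
    )
}
```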
diff --git a/src/options/fragments/llm.tsx b/src/options/fragments/llm.tsx
index 561f2a1d..8c5ed42a 100644
--- a/src/options/fragments/llm.tsx
+++ b/src/options/fragments/llm.tsx
@@ -1,9 +1,10 @@
 import { Button, Input, Tooltip, Typography } from "@material-tailwind/react"
-import { Fragment, useState, type ChangeEvent, type ReactNode } from "react"
+import { Fragment, useEffect, useMemo, useRef, useState, type ChangeEvent, type ReactNode } from "react"
 import { toast } from "sonner/dist"
 import type { StateProxy } from "~hooks/binding"
 import type { LLMTypes } from "~llms"
 import createLLMProvider from "~llms"
+import models from "~llms/models"
 import Selector from "~options/components/Selector"
 
 export type SettingSchema = {
@@ -55,21 +56,50 @@ function Hints({ children }: { children: ReactNode }): JSX.Element {
 
 function LLMSettings({ state, useHandler }: StateProxy<SettingSchema>): JSX.Element {
     const [validating, setValidating] = useState(false)
+    const toastValidating = useRef<string | number | null>(null)
     const handler = useHandler<ChangeEvent<HTMLInputElement>, string>((e) => e.target.value)
 
+    const selectableModels = useMemo(
+        () => models
+            .filter(({ providers }) => providers.includes(state.provider))
+            .flatMap(({ models }) => models)
+            .map(model => ({ label: model, value: model })),
+        [state.provider]
+    )
+
+    const onSwitchProvider = (provider: LLMTypes) => {
+        state.provider = provider
+        state.model = undefined // reset model
+        if (provider === 'webllm') {
+            toast.info('使用 WEBLLM 时,请确保你的电脑拥有足够的算力以供 AI 运行。', { position: 'top-center' })
+        }
+    }
+
     const onValidate = async () => {
         setValidating(true)
-        try {
-            const provider = createLLMProvider(state)
-            await provider.validate()
-            toast.success('配置可用!')
-        } catch (e) {
-            toast.error('配置不可用: ' + e.message)
-        } finally {
-            setValidating(false)
-        }
+        const provider = createLLMProvider(state)
+        const validation = provider.validate((p, t) => {
+            if (toastValidating.current) {
+                toast.loading(`${t}... (${Math.round(p * 100)}%)`, {
+                    id: toastValidating.current
+                })
+            }
+        })
+        toast.dismiss()
+        toastValidating.current = toast.promise(validation, {
+            loading: `正在验证配置...`,
+            success: '配置可用!',
+            error: err => '配置不可用: ' + (err.message ?? err),
+            position: 'bottom-center',
+            duration: Infinity,
+            finally: () => setValidating(false)
+        })
     }
 
+    console.log('provider: ', state.provider)
+    console.log('model: ', state.model)
+
     return (
@@ -77,11 +107,12 @@ function LLMSettings({ state, useHandler }: StateProxy<SettingSchema>): JSX.Elem
             <Selector
                 data-testid="ai-provider"
                 label="技术提供"
                 value={state.provider}
-                onChange={e => state.provider = e}
+                onChange={onSwitchProvider}
                 options={[
-                    { label: 'Cloudflare AI', value: 'cloudflare' },
-                    { label: '公共服务器', value: 'worker' },
-                    { label: 'Chrome 浏览器内置 AI', value: 'nano' }
+                    { label: 'Cloudflare AI (云)', value: 'cloudflare' },
+                    { label: '公共服务器 (云)', value: 'worker' },
+                    { label: 'Chrome 浏览器内置 AI (本地)', value: 'nano' },
+                    { label: 'Web LLM (本地)', value: 'webllm' }
                 ]}
             />
             {state.provider === 'cloudflare' && (
@@ -110,27 +141,22 @@ function LLMSettings({ state, useHandler }: StateProxy<SettingSchema>): JSX.Elem
                 />
             )}
-            {['cloudflare', 'worker'].includes(state.provider) && (
+            {state.provider === 'nano' && (
+                <Hints>
+                    点击此处
+                    查看如何启用 Chrome 浏览器内置 AI
+                </Hints>
+            )}
+            {selectableModels.length > 0 && (
                 <Selector
                     data-testid="ai-model"
                     label="模型提供"
                     value={state.model}
                     onChange={e => state.model = e}
-                    options={[
-                        { label: '@cf/qwen/qwen1.5-14b-chat-awq', value: '@cf/qwen/qwen1.5-14b-chat-awq' },
-                        { label: '@cf/qwen/qwen1.5-7b-chat-awq', value: '@cf/qwen/qwen1.5-7b-chat-awq' },
-                        { label: '@cf/qwen/qwen1.5-1.8b-chat', value: '@cf/qwen/qwen1.5-1.8b-chat' },
-                        { label: '@hf/google/gemma-7b-it', value: '@hf/google/gemma-7b-it' },
-                        { label: '@hf/nousresearch/hermes-2-pro-mistral-7b', value: '@hf/nousresearch/hermes-2-pro-mistral-7b' }
-                    ]}
+                    options={selectableModels}
+                    emptyValue="默认"
                 />
             )}
-            {state.provider === 'nano' && (
-                <Hints>
-                    点击此处
-                    查看如何启用 Chrome 浏览器内置 AI
-                </Hints>
-            )}
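One subtlety in the validation flow above: `toast.promise` returns the toast id, and the progress callback then reuses that id through `toast.loading` to update the same toast in place rather than stacking new ones. A schematic of that pattern with the provider stubbed out — `fakeValidate` is a stand-in, not the patch's code:

```typescript
// Schematic of onValidate's update-in-place toast pattern (sonner).
// `fakeValidate` stands in for provider.validate from this patch.
import { toast } from 'sonner/dist'

async function fakeValidate(progress: (p: number, t: string) => void): Promise<void> {
    for (let i = 1; i <= 10; i++) {
        await new Promise(resolve => setTimeout(resolve, 100))
        progress(i / 10, '正在下载AI模型到本地')
    }
}

let toastId: string | number | null = null

const validation = fakeValidate((p, t) => {
    // Reuse the id returned by toast.promise so the toast updates in place.
    if (toastId) toast.loading(`${t}... (${Math.round(p * 100)}%)`, { id: toastId })
})

toastId = toast.promise(validation, {
    loading: '正在验证配置...',
    success: '配置可用!',
    error: (err: any) => '配置不可用: ' + (err.message ?? err)
})
```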