added web-llm as llm provider
eric2788 committed Oct 23, 2024
1 parent 9dfbf68 commit 5a795d2
Showing 11 changed files with 192 additions and 38 deletions.
1 change: 1 addition & 0 deletions package.json
@@ -27,6 +27,7 @@
"@ffmpeg/ffmpeg": "^0.12.10",
"@ffmpeg/util": "^0.12.1",
"@material-tailwind/react": "^2.1.9",
"@mlc-ai/web-llm": "^0.2.73",
"@plasmohq/messaging": "^0.6.2",
"@plasmohq/storage": "^1.9.3",
"@react-hooks-library/core": "^0.5.2",
16 changes: 16 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion src/contents/index/mounter.tsx
@@ -173,7 +173,7 @@ function createApp(roomId: string, plasmo: PlasmoSpec, info: StreamInfo): App {
onClick: () => this.start()
},
position: 'top-left',
duration: 6000000,
duration: Infinity,
dismissible: false
})
return
2 changes: 1 addition & 1 deletion src/llms/cloudflare-ai.ts
@@ -41,7 +41,7 @@ export default class CloudFlareAI implements LLMProviders {
console.warn('Cloudflare AI session is not supported')
return {
...this,
[Symbol.dispose]: () => { }
[Symbol.asyncDispose]: async () => { }
}
}

8 changes: 4 additions & 4 deletions src/llms/gemini-nano.ts
@@ -19,7 +19,7 @@ export default class GeminiNano implements LLMProviders {
return await session.prompt(chat) // await before the finally block disposes the session
} finally {
console.debug('[gemini nano] done')
session[Symbol.dispose]()
await session[Symbol.asyncDispose]()
}
}

@@ -33,7 +33,7 @@ export default class GeminiNano implements LLMProviders {
}
} finally {
console.debug('[gemini nano] done')
session[Symbol.dispose]()
await session[Symbol.asyncDispose]()
}
}

@@ -82,7 +82,7 @@ class GeminiAssistant implements Session<LLMProviders> {
}
}

[Symbol.dispose](): void {
async [Symbol.asyncDispose]() {
this.assistant.destroy()
}
}
@@ -105,7 +105,7 @@ class GeminiSummarizer implements Session<LLMProviders> {
}
}

[Symbol.dispose](): void {
async [Symbol.asyncDispose]() {
this.summarizer.destroy()
}

6 changes: 4 additions & 2 deletions src/llms/index.ts
@@ -3,6 +3,7 @@ import type { SettingSchema as LLMSchema } from '~options/fragments/llm'
import cloudflare from './cloudflare-ai'
import nano from './gemini-nano'
import worker from './remote-worker'
import webllm from './web-llm'

export interface LLMProviders {
cumulative: boolean
@@ -12,12 +13,13 @@ export interface LLMProviders {
asSession(): Promise<Session<LLMProviders>>
}

export type Session<T> = Disposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>
export type Session<T> = AsyncDisposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>

const llms = {
cloudflare,
nano,
worker
worker,
webllm
}

export type LLMs = typeof llms
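
Because Session<T> is now AsyncDisposable rather than Disposable, call sites can lean on TC39 explicit resource management instead of disposing by hand. A minimal sketch of a consumer, assuming TypeScript 5.2+ with Symbol.asyncDispose available (or polyfilled), and reusing the createLLMProvider default export imported by llm.tsx further below:

import createLLMProvider from '~llms'
import type { SettingSchema } from '~options/fragments/llm'

// Illustrative call site: `settings` would come from the extension's saved options.
async function runPrompt(settings: SettingSchema, chat: string): Promise<string> {
    const provider = createLLMProvider(settings)
    // `await using` invokes session[Symbol.asyncDispose]() when the scope exits,
    // even on exceptions, so no manual try/finally is needed.
    await using session = await provider.asSession()
    return await session.prompt(chat)
}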
32 changes: 32 additions & 0 deletions src/llms/models.ts
@@ -0,0 +1,32 @@
import type { LLMTypes } from "~llms"

export type ModelList = {
providers: LLMTypes[]
models: string[]
}

const models: ModelList[] = [
{
providers: ['worker', 'cloudflare'],
models: [
'@cf/qwen/qwen1.5-14b-chat-awq',
'@cf/qwen/qwen1.5-7b-chat-awq',
'@cf/qwen/qwen1.5-1.8b-chat',
'@hf/google/gemma-7b-it',
'@hf/nousresearch/hermes-2-pro-mistral-7b'
]
},
{
providers: ['webllm'],
models: [
'Qwen2-7B-Instruct-q4f32_1-MLC',
'Llama-3.1-8B-Instruct-q4f32_1-MLC',
'Qwen2.5-14B-Instruct-q4f16_1-MLC',
'gemma-2-9b-it-q4f16_1-MLC',
'Qwen2.5-3B-Instruct-q0f16-MLC'
]
}
]

export default models
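
The options page below consumes this registry by filtering on the currently selected provider. Pulled out as a standalone helper, the lookup could read as follows (a sketch mirroring the useMemo in llm.tsx; modelsFor is a hypothetical name):

import type { LLMTypes } from '~llms'
import models from '~llms/models'

// Collect every model whose entry lists the given provider.
// Providers absent from the registry (e.g. 'nano') yield an empty list.
function modelsFor(provider: LLMTypes): string[] {
    return models
        .filter(({ providers }) => providers.includes(provider))
        .flatMap(({ models }) => models)
}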
2 changes: 1 addition & 1 deletion src/llms/remote-worker.ts
@@ -54,7 +54,7 @@ export default class RemoteWorker implements LLMProviders {
console.warn('Remote worker session is not supported')
return {
...this,
[Symbol.dispose]: () => { }
[Symbol.asyncDispose]: async () => { }
}
}

84 changes: 84 additions & 0 deletions src/llms/web-llm.ts
@@ -0,0 +1,84 @@
import { CreateMLCEngine, MLCEngine } from "@mlc-ai/web-llm";
import type { LLMProviders, Session } from "~llms";
import type { SettingSchema } from "~options/fragments/llm";


export default class WebLLM implements LLMProviders {

private static readonly DEFAULT_MODEL: string = 'Qwen2-7B-Instruct-q4f32_1-MLC'

private readonly model: string
private readonly engine: MLCEngine

constructor(settings: SettingSchema) {
this.model = settings.model || WebLLM.DEFAULT_MODEL
this.engine = new MLCEngine()
}

cumulative: boolean = true

async validate(): Promise<void> {
await this.engine.reload(this.model)
}

async prompt(chat: string): Promise<string> {
const session = await this.asSession()
try {
console.debug('[web-llm] prompting: ', chat)
return await session.prompt(chat) // await before the finally block disposes the session
} finally {
console.debug('[web-llm] done')
await session[Symbol.asyncDispose]()
}
}

async *promptStream(chat: string): AsyncGenerator<string> {
const session = await this.asSession()
try {
console.debug('[web-llm] prompting stream: ', chat)
const res = session.promptStream(chat)
for await (const chunk of res) {
yield chunk
}
} finally {
console.debug('[web-llm] done')
await session[Symbol.asyncDispose]()
}
}

async asSession(): Promise<Session<LLMProviders>> {
const engine = await CreateMLCEngine(this.model, {
initProgressCallback: (progress) => {
console.log('[web-llm] init progress:', progress)
}
})
return {
async prompt(chat: string) {
await engine.interruptGenerate()
const c = await engine.completions.create({
prompt: chat,
max_tokens: 512,
temperature: 0.2,
})
return c.choices[0]?.text ?? engine.getMessage()
},
async *promptStream(chat: string): AsyncGenerator<string> {
await engine.interruptGenerate()
const chunks = await engine.completions.create({
prompt: chat,
max_tokens: 512,
temperature: 0.2,
stream: true
})
for await (const chunk of chunks) {
yield chunk.choices[0]?.text || "";
if (chunk.usage) {
console.debug('Usage:', chunk.usage)
}
}
},
[Symbol.asyncDispose]: () => engine.unload() // wrap the call so unload keeps the engine as `this`
}
}

}
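
For context, a minimal sketch of driving this provider end to end. The settings literal is illustrative (SettingSchema carries more fields than shown, hence the cast), and the first validate() call downloads the model weights into the browser, so it can take several minutes:

import WebLLM from '~llms/web-llm'
import type { SettingSchema } from '~options/fragments/llm'

// Any model ID from ~llms/models tagged 'webllm' should work here.
const llm = new WebLLM({ model: 'Qwen2.5-3B-Instruct-q0f16-MLC' } as SettingSchema)
await llm.validate() // downloads and loads the weights on first run

let answer = ''
for await (const chunk of llm.promptStream('Summarize this stream in one sentence.')) {
    answer += chunk
}
console.log(answer)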
3 changes: 2 additions & 1 deletion src/options/components/Selector.tsx
@@ -15,6 +15,7 @@ export type SelectorProps<T> = {
disabled?: boolean
options: SelectorOption<T>[]
className?: string
emptyValue?: string
}


@@ -37,7 +38,7 @@ function Selector<T = any>(props: SelectorProps<T>): JSX.Element {
<label className="text-sm ml-1 font-medium text-gray-900 dark:text-white">{props.label}</label>
<div ref={dropdownRef} className={`mt-2 ${props.disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`} onClick={() => !props.disabled && setOpen(!isOpen)}>
<div className={`inline-flex justify-between h-full w-full rounded-md border border-gray-300 dark:border-gray-600 shadow-sm px-4 py-2 text-sm font-medium text-gray-700 dark:text-white ${props.disabled ? 'opacity-50 bg-transparent' : 'bg-white dark:bg-gray-800 hover:bg-gray-50 dark:hover:bg-gray-900'} focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-offset-gray-100 dark:focus:ring-gray-500`}>
{props.options.find((option) => option.value === props.value)?.label ?? String(props.value)}
{props.options.find((option) => option.value === props.value)?.label || String(props.value || props.emptyValue || '请选择')}
<svg className="-mr-1 ml-2 h-5 w-5" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fillRule="evenodd" d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z" clipRule="evenodd" />
</svg>
74 changes: 46 additions & 28 deletions src/options/fragments/llm.tsx
@@ -1,9 +1,10 @@
import { Button, Input, Tooltip, Typography } from "@material-tailwind/react"
import { Fragment, useState, type ChangeEvent, type ReactNode } from "react"
import { Fragment, useMemo, useState, type ChangeEvent, type ReactNode } from "react"
import { toast } from "sonner/dist"
import type { StateProxy } from "~hooks/binding"
import type { LLMTypes } from "~llms"
import createLLMProvider from "~llms"
import models from "~llms/models"
import Selector from "~options/components/Selector"

export type SettingSchema = {
@@ -57,31 +58,53 @@ function LLMSettings({ state, useHandler }: StateProxy<SettingSchema>): JSX.Element {
const [validating, setValidating] = useState(false)
const handler = useHandler<ChangeEvent<HTMLInputElement>, string>((e) => e.target.value)

const selectableModels = useMemo(
() => models
.filter(({ providers }) => providers.includes(state.provider))
.flatMap(({ models }) => models)
.map(model => ({ label: model, value: model })),
[state.provider]
)

const onSwitchProvider = (provider: LLMTypes) => {
state.provider = provider
state.model = undefined // reset model
if (provider === 'webllm') {
toast.info('使用 WEBLLM 时,请确保你的电脑拥有足够的算力以供 AI 运行。')
}
}

const onValidate = async () => {
setValidating(true)
try {
const provider = createLLMProvider(state)
await provider.validate()
toast.success('配置可用!')
} catch (e) {
toast.error('配置不可用: ' + e.message)
} finally {
setValidating(false)
}
const provider = createLLMProvider(state)
const validation = provider.validate()
toast.dismiss()
toast.promise(validation, {
loading: '正在验证配置...',
success: '配置可用!',
error: err => '配置不可用: ' + (err.message ?? err),
position: 'bottom-center',
duration: Infinity,
finally: () => setValidating(false)
})
}

console.log('provider: ', state.provider)
console.log('model: ', state.model)

return (
<Fragment>
<Selector<typeof state.provider>
className="col-span-2"
data-testid="ai-provider"
label="技术提供"
value={state.provider}
onChange={e => state.provider = e}
onChange={onSwitchProvider}
options={[
{ label: 'Cloudflare AI', value: 'cloudflare' },
{ label: '公共服务器', value: 'worker' },
{ label: 'Chrome 浏览器内置 AI', value: 'nano' }
{ label: 'Cloudflare AI (云)', value: 'cloudflare' },
{ label: '公共服务器 (云)', value: 'worker' },
{ label: 'Chrome 浏览器内置 AI (本地)', value: 'nano' },
{ label: 'Web LLM (本地)', value: 'webllm' }
]}
/>
{state.provider === 'cloudflare' && (
@@ -110,27 +133,22 @@ function LLMSettings({ state, useHandler }: StateProxy<SettingSchema>): JSX.Element {
/>
</Fragment>
)}
{['cloudflare', 'worker'].includes(state.provider) && (
{state.provider === 'nano' && (
<Hints>
<Typography className="underline" as="a" href="https://juejin.cn/post/7401036139384143910" target="_blank">点击此处</Typography>
查看如何启用 Chrome 浏览器内置 AI
</Hints>
)}
{selectableModels.length > 0 && (
<Selector<string>
data-testid="ai-model"
label="模型提供"
value={state.model}
onChange={e => state.model = e}
options={[
{ label: '@cf/qwen/qwen1.5-14b-chat-awq', value: '@cf/qwen/qwen1.5-14b-chat-awq' },
{ label: '@cf/qwen/qwen1.5-7b-chat-awq', value: '@cf/qwen/qwen1.5-7b-chat-awq' },
{ label: '@cf/qwen/qwen1.5-1.8b-chat', value: '@cf/qwen/qwen1.5-1.8b-chat' },
{ label: '@hf/google/gemma-7b-it', value: '@hf/google/gemma-7b-it' },
{ label: '@hf/nousresearch/hermes-2-pro-mistral-7b', value: '@hf/nousresearch/hermes-2-pro-mistral-7b' }
]}
options={selectableModels}
emptyValue="默认"
/>
)}
{state.provider === 'nano' && (
<Hints>
<Typography className="underline" as="a" href="https://juejin.cn/post/7401036139384143910" target="_blank">点击此处</Typography>
查看如何启用 Chrome 浏览器内置 AI
</Hints>
)}
<div className="col-span-2">
<Button disabled={validating} onClick={onValidate} color="blue" size="lg" className="group flex items-center justify-center gap-3 text-[1rem] hover:shadow-lg">
验证是否可用
