Skip to content

Commit

Permalink
Added web-llm as an LLM provider
Browse files Browse the repository at this point in the history
  • Loading branch information
eric2788 committed Oct 24, 2024
1 parent 48d015a commit 35a32dc
Show file tree
Hide file tree
Showing 14 changed files with 311 additions and 39 deletions.
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"@ffmpeg/ffmpeg": "^0.12.10",
"@ffmpeg/util": "^0.12.1",
"@material-tailwind/react": "^2.1.9",
"@mlc-ai/web-llm": "^0.2.73",
"@plasmohq/messaging": "^0.6.2",
"@plasmohq/storage": "^1.9.3",
"@react-hooks-library/core": "^0.5.2",
Expand Down Expand Up @@ -64,6 +65,7 @@
"@types/react": "18.2.37",
"@types/react-dom": "18.2.15",
"@types/semver": "^7.5.8",
"@webgpu/types": "^0.1.49",
"dotenv": "^16.4.5",
"esbuild": "^0.20.2",
"gify-parse": "^1.0.7",
Expand Down
24 changes: 24 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/contents/index/mounter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function createApp(roomId: string, plasmo: PlasmoSpec, info: StreamInfo): App {
onClick: () => this.start()
},
position: 'top-left',
duration: 6000000,
duration: Infinity,
dismissible: false
})
return
Expand Down
30 changes: 30 additions & 0 deletions src/hooks/life-cycle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,34 @@ export function useTimeoutElement(before: JSX.Element, after: JSX.Element, timeo
}, [after, before, timeout]);

return element;
}

/**
* A React hook that ensures the provided disposable resource is properly disposed of
* when the component is unmounted.
*
* @param disposable - The resource to be disposed of, which can be either a `Disposable`
* or an `AsyncDisposable`. The resource must implement either the `Symbol.dispose` or
* `Symbol.asyncDispose` method.
*
* @example
* ```typescript
* const MyComponent = () => {
* const disposableResource = useMemo(() => createDisposableResource(), []);
* withDisposable(disposableResource);
*
* return <div>My Component</div>;
* };
* ```
*/
/**
 * A React hook that disposes the provided resource when the component
 * unmounts, or when a different resource instance is supplied.
 *
 * Synchronous `Symbol.dispose` is preferred when present; otherwise
 * `Symbol.asyncDispose` is invoked fire-and-forget — React effect
 * cleanups cannot await, so rejections are logged rather than thrown.
 *
 * @param disposable - The resource to be disposed of, which can be either a
 * `Disposable` or an `AsyncDisposable`. The resource must implement either
 * the `Symbol.dispose` or `Symbol.asyncDispose` method. Memoize it (e.g.
 * with `useMemo`) so its identity is stable across renders.
 *
 * @example
 * ```typescript
 * const MyComponent = () => {
 *     const disposableResource = useMemo(() => createDisposableResource(), []);
 *     useDisposable(disposableResource);
 *
 *     return <div>My Component</div>;
 * };
 * ```
 */
export function useDisposable(disposable: Disposable | AsyncDisposable) {
    useEffect(() => {
        return () => {
            // Narrow the union with an `in` check on the well-known symbol
            // instead of a truthy property read (strict-mode friendly).
            if (Symbol.dispose in disposable) {
                disposable[Symbol.dispose]()
            } else if (Symbol.asyncDispose in disposable) {
                // Cleanup cannot be async; surface failures via the console
                // instead of leaving an unhandled rejection.
                Promise.resolve(disposable[Symbol.asyncDispose]()).catch(console.error)
            }
        }
        // Depend on `disposable`: with the original `[]` a re-rendered
        // replacement resource was silently leaked, and only the first
        // instance was ever disposed.
    }, [disposable])
}
2 changes: 1 addition & 1 deletion src/llms/cloudflare-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export default class CloudFlareAI implements LLMProviders {
console.warn('Cloudflare AI session is not supported')
return {
...this,
[Symbol.dispose]: () => { }
[Symbol.asyncDispose]: async () => { }
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/llms/gemini-nano.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export default class GeminiNano implements LLMProviders {
return session.prompt(chat)
} finally {
console.debug('[gemini nano] done')
session[Symbol.dispose]()
await session[Symbol.asyncDispose]()
}
}

Expand All @@ -33,7 +33,7 @@ export default class GeminiNano implements LLMProviders {
}
} finally {
console.debug('[gemini nano] done')
session[Symbol.dispose]()
await session[Symbol.asyncDispose]()
}
}

Expand Down Expand Up @@ -82,7 +82,7 @@ class GeminiAssistant implements Session<LLMProviders> {
}
}

[Symbol.dispose](): void {
async [Symbol.asyncDispose]() {
this.assistant.destroy()
}
}
Expand All @@ -105,7 +105,7 @@ class GeminiSummarizer implements Session<LLMProviders> {
}
}

[Symbol.dispose](): void {
async [Symbol.asyncDispose]() {
this.summarizer.destroy()
}

Expand Down
8 changes: 5 additions & 3 deletions src/llms/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@ import type { SettingSchema as LLMSchema } from '~options/fragments/llm'
import cloudflare from './cloudflare-ai'
import nano from './gemini-nano'
import worker from './remote-worker'
import webllm from './web-llm'

export interface LLMProviders {
cumulative: boolean
validate(): Promise<void>
validate(progress?: (p: number, t: string) => void): Promise<void>
prompt(chat: string): Promise<string>
promptStream(chat: string): AsyncGenerator<string>
asSession(): Promise<Session<LLMProviders>>
}

export type Session<T> = Disposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>
export type Session<T> = AsyncDisposable & Omit<T, 'asSession' | 'validate' | 'cumulative'>

const llms = {
cloudflare,
nano,
worker
worker,
webllm
}

export type LLMs = typeof llms
Expand Down
33 changes: 33 additions & 0 deletions src/llms/models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import type { LLMTypes } from "~llms"

/** Associates a group of LLM providers with the model identifiers they accept. */
export type ModelList = {
    providers: LLMTypes[]
    models: string[]
}

/** Cloudflare-hosted chat models, shared by the remote worker and Cloudflare AI providers. */
const cloudflareHosted: ModelList = {
    providers: ['worker', 'cloudflare'],
    models: [
        '@cf/qwen/qwen1.5-14b-chat-awq',
        '@cf/qwen/qwen1.5-7b-chat-awq',
        '@cf/qwen/qwen1.5-1.8b-chat',
        '@hf/google/gemma-7b-it',
        '@hf/nousresearch/hermes-2-pro-mistral-7b'
    ]
}

/** Models that run fully in the browser via the WebLLM (MLC) engine. */
const webllmLocal: ModelList = {
    providers: [ 'webllm' ],
    models: [
        'Qwen2-7B-Instruct-q4f32_1-MLC',
        'Qwen2.5-14B-Instruct-q4f16_1-MLC',
        'gemma-2-9b-it-q4f16_1-MLC',
        'Qwen2.5-3B-Instruct-q0f16-MLC',
        'Phi-3-mini-128k-instruct-q0f16-MLC',
        'Phi-3.5-mini-instruct-q4f16_1-MLC-1k'
    ]
}

/** All known provider→model mappings, in display order. */
const models: ModelList[] = [cloudflareHosted, webllmLocal]

export default models
2 changes: 1 addition & 1 deletion src/llms/remote-worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export default class RemoteWorker implements LLMProviders {
console.warn('Remote worker session is not supported')
return {
...this,
[Symbol.dispose]: () => { }
[Symbol.asyncDispose]: async () => { }
}
}

Expand Down
87 changes: 87 additions & 0 deletions src/llms/web-llm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import type { MLCEngine } from "@mlc-ai/web-llm";
import type { LLMProviders, Session } from "~llms";
import type { SettingSchema } from "~options/fragments/llm";

/**
 * LLM provider that runs models fully in the browser via the WebLLM (MLC)
 * engine. The engine bundle is imported lazily so the heavy dependency is
 * only loaded when this provider is actually used.
 */
export default class WebLLM implements LLMProviders {

    private static readonly DEFAULT_MODEL: string = 'Qwen2-7B-Instruct-q4f32_1-MLC'

    /** Model identifier handed to CreateMLCEngine; falls back to DEFAULT_MODEL. */
    private readonly model: string

    constructor(settings: SettingSchema) {
        this.model = settings.model || WebLLM.DEFAULT_MODEL
    }

    // Conversation state accumulates across prompts within one session.
    cumulative: boolean = true

    /**
     * Validates the provider by initializing the engine, which downloads
     * the model to the local cache on first use.
     * @param progresser - optional callback receiving progress (0..1) and a status text
     */
    async validate(progresser?: (p: number, t: string) => void): Promise<void> {
        await this.initializeEngine(progresser)
    }

    /**
     * Runs a single prompt in a throwaway session.
     * @param chat - the prompt text
     * @returns the completion text
     */
    async prompt(chat: string): Promise<string> {
        const session = await this.asSession()
        try {
            console.debug('[web-llm] prompting: ', chat)
            // `await` is required here: with a bare `return session.prompt(chat)`
            // the finally block disposed the session while generation was
            // still in flight (original bug).
            return await session.prompt(chat)
        } finally {
            console.debug('[web-llm] done')
            await session[Symbol.asyncDispose]()
        }
    }

    /**
     * Streams a single prompt's completion chunks from a throwaway session.
     * The session is disposed once the generator finishes or is abandoned.
     */
    async *promptStream(chat: string): AsyncGenerator<string> {
        const session = await this.asSession()
        try {
            console.debug('[web-llm] prompting stream: ', chat)
            const res = session.promptStream(chat)
            for await (const chunk of res) {
                yield chunk
            }
        } finally {
            console.debug('[web-llm] done')
            await session[Symbol.asyncDispose]()
        }
    }

    /**
     * Creates a session backed by a freshly initialized engine. Disposing
     * the session unloads the engine and frees its (GPU) resources.
     */
    async asSession(): Promise<Session<LLMProviders>> {
        const engine = await this.initializeEngine()
        return {
            async prompt(chat: string) {
                // Cancel any generation still running from a previous call.
                await engine.interruptGenerate()
                const c = await engine.completions.create({
                    prompt: chat,
                    max_tokens: 512,
                    temperature: 0.2,
                })
                return c.choices[0]?.text ?? engine.getMessage()
            },
            async *promptStream(chat: string): AsyncGenerator<string> {
                await engine.interruptGenerate()
                const chunks = await engine.completions.create({
                    prompt: chat,
                    max_tokens: 512,
                    temperature: 0.2,
                    stream: true
                })
                for await (const chunk of chunks) {
                    yield chunk.choices[0]?.text || "";
                    if (chunk.usage) {
                        console.debug('Usage:', chunk.usage)
                    }
                }
            },
            // Wrap in an arrow function: assigning `engine.unload` directly
            // (as the original did) detaches the method from `engine`, so
            // `this` was undefined when invoked as session[Symbol.asyncDispose]().
            [Symbol.asyncDispose]: async () => { await engine.unload() }
        }
    }

    // NOTE(review): a fresh engine is created per validate()/asSession() call
    // and validate() never unloads its engine — confirm this does not leak
    // GPU memory when validation runs repeatedly.
    private async initializeEngine(progresser?: (p: number, t: string) => void): Promise<MLCEngine> {
        const { CreateMLCEngine } = await import('@mlc-ai/web-llm')
        return CreateMLCEngine(this.model, {
            initProgressCallback: (progress) => {
                progresser?.(progress.progress, "正在下载AI模型到本地")
                console.log('初始化进度:', progress)
            }
        })
    }

}
3 changes: 2 additions & 1 deletion src/options/components/Selector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export type SelectorProps<T> = {
disabled?: boolean
options: SelectorOption<T>[]
className?: string
emptyValue?: string
}


Expand All @@ -37,7 +38,7 @@ function Selector<T = any>(props: SelectorProps<T>): JSX.Element {
<label className="text-sm ml-1 font-medium text-gray-900 dark:text-white">{props.label}</label>
<div ref={dropdownRef} className={`mt-2 ${props.disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`} onClick={() => !props.disabled && setOpen(!isOpen)}>
<div className={`inline-flex justify-between h-full w-full rounded-md border border-gray-300 dark:border-gray-600 shadow-sm px-4 py-2 text-sm font-medium text-gray-700 dark:text-white ${props.disabled ? 'opacity-50 bg-transparent' : 'bg-white dark:bg-gray-800 hover:bg-gray-50 dark:hover:bg-gray-900'} focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-offset-gray-100 dark:focus:ring-gray-500`}>
{props.options.find((option) => option.value === props.value)?.label ?? String(props.value)}
{props.options.find((option) => option.value === props.value)?.label || String(props.value || (props.emptyValue || '请选择'))}
<svg className="-mr-1 ml-2 h-5 w-5" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fillRule="evenodd" d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z" clipRule="evenodd" />
</svg>
Expand Down
Loading

0 comments on commit 35a32dc

Please sign in to comment.