From b3b8b0a54c61b4d6b13a72845908b52828987076 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Fri, 18 Oct 2024 18:09:17 +0800 Subject: [PATCH 01/15] init commit for ai feature --- src/llms/cloudflare-ai.ts | 59 ++++++++ src/llms/gemini-nano.ts | 89 +++++++++++ src/llms/index.ts | 27 ++++ src/llms/remote-worker.ts | 54 +++++++ src/types/extends/chrome-ai.d.ts | 247 +++++++++++++++++++++++++++++++ tests/modules/llm.js | 7 + tests/modules/utils.js | 3 + tests/units/llm.spec.ts | 63 ++++++++ 8 files changed, 549 insertions(+) create mode 100644 src/llms/cloudflare-ai.ts create mode 100644 src/llms/gemini-nano.ts create mode 100644 src/llms/index.ts create mode 100644 src/llms/remote-worker.ts create mode 100644 src/types/extends/chrome-ai.d.ts create mode 100644 tests/modules/llm.js create mode 100644 tests/units/llm.spec.ts diff --git a/src/llms/cloudflare-ai.ts b/src/llms/cloudflare-ai.ts new file mode 100644 index 00000000..f3f2f477 --- /dev/null +++ b/src/llms/cloudflare-ai.ts @@ -0,0 +1,59 @@ +import type { LLMProviders, Session } from "~llms"; + +export default class CloudflareAI implements LLMProviders { + + constructor( + private readonly accountId: string, + private readonly apiToken: String, + private readonly model = '@cf/facebook/bart-large-cnn' // text summarization model + ) { } + + async validate(): Promise { + const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/models/search?per_page=1`, { + headers: { + Authorization: `Bearer ${this.apiToken}` + } + }) + const json = await res.json() + if (!json.success) throw new Error('Cloudflare API 验证失败') + } + + async prompt(chat: string): Promise { + const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { + headers: { + Authorization: `Bearer ${this.apiToken}` + }, + body: JSON.stringify({ prompt: chat }) + }) + const json = await res.json() + return json.response + } + + async *promptStream(chat: string): AsyncGenerator { + const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { + headers: { + Authorization: `Bearer ${this.apiToken}` + }, + body: JSON.stringify({ prompt: chat, stream: true }) + }) + if (!res.body) throw new Error('Cloudflare AI response body is not readable') + const reader = res.body.getReader() + const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) + while (true) { + const { done, value } = await reader.read() + if (done) break + const { response } = JSON.parse(decoder.decode(value, { stream: true })) + yield response + } + } + + async asSession(): Promise> { + console.warn('Cloudflare AI session is not supported') + return { + ...this, + [Symbol.dispose]: () => { } + } + } + + +} \ No newline at end of file diff --git a/src/llms/gemini-nano.ts b/src/llms/gemini-nano.ts new file mode 100644 index 00000000..7ab3fd4a --- /dev/null +++ b/src/llms/gemini-nano.ts @@ -0,0 +1,89 @@ +import type { LLMProviders, Session } from "~llms" + +export default class GeminiNano implements LLMProviders { + + async validate(): Promise { + if (!window.ai) throw new Error('你的浏览器没有启用 AI 功能') + if (!window.ai.languageModel && + !window.ai.assistant && + !window.ai.summarizer + ) throw new Error('你的浏览器没有启用 AI 功能') + } + + async prompt(chat: string): Promise { + using session = await this.asSession() + return session.prompt(chat) + } + + async *promptStream(chat: string): AsyncGenerator { + using session = await this.asSession() + return session.promptStream(chat) + } + + async asSession(): Promise> { + + if (window.ai.assistant || window.ai.languageModel) { + const assistant = window.ai.assistant ?? window.ai.languageModel + const capabilities = await assistant.capabilities() + if (capabilities.available === 'readily') { + return new GeminiAssistant(await assistant.create()) + } else { + console.warn('AI Assistant 当前不可用: ', capabilities) + } + } + + if (window.ai.summarizer) { + const summarizer = window.ai.summarizer + const capabilities = await summarizer.capabilities() + if (capabilities.available === 'readily') { + return new GeminiSummarizer(await summarizer.create()) + } else { + console.warn('AI Summarizer 当前不可用: ', capabilities) + } + } + + throw new Error('你的浏览器 AI 功能当前不可用') + } +} + +class GeminiAssistant implements Session { + + constructor(private readonly assistant: AIAssistant) { } + + prompt(chat: string): Promise { + return this.assistant.prompt(chat) + } + + async *promptStream(chat: string): AsyncGenerator { + const stream = this.assistant.promptStreaming(chat) + for await (const chunk of stream) { + yield chunk + } + } + + [Symbol.dispose](): void { + this.assistant.destroy() + } +} + + +class GeminiSummarizer implements Session { + + constructor(private readonly summarizer: AISummarizer) { } + + prompt(chat: string): Promise { + return this.summarizer.summarize(chat) + } + + async *promptStream(chat: string): AsyncGenerator { + const stream = this.summarizer.summarizeStreaming(chat) + for await (const chunk of stream) { + yield chunk + } + } + + [Symbol.dispose](): void { + this.summarizer.destroy() + } + +} diff --git a/src/llms/index.ts b/src/llms/index.ts new file mode 100644 index 00000000..92373757 --- /dev/null +++ b/src/llms/index.ts @@ -0,0 +1,27 @@ +import cloudflare from './cloudflare-ai' +import nano from './gemini-nano' +import worker from './remote-worker' + +export interface LLMProviders { + validate(): Promise + prompt(chat: string): Promise + promptStream(chat: string): AsyncGenerator + asSession(): Promise> +} + +export type Session = Disposable & Omit + +const llms = { + cloudflare, + nano, + worker +} + +export type LLMTypes = keyof typeof llms + +export async function createLLMProvider(type: LLMTypes, ...args: any[]): Promise { + const LLM = llms[type].bind(this, ...args) + return new LLM() +} + +export default createLLMProvider \ No newline at end of file diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts new file mode 100644 index 00000000..42a4a8ba --- /dev/null +++ b/src/llms/remote-worker.ts @@ -0,0 +1,54 @@ +import type { LLMProviders, Session } from "~llms"; + + +// for my worker, so limited usage +export default class RemoteWorker implements LLMProviders { + + async validate(): Promise { + const res = await fetch('https://llm.ericlamm.xyz/status') + const json = await res.json() + if (json.status !== 'working') { + throw new Error('Remote worker is not working') + } + } + + async prompt(chat: string): Promise { + const res = await fetch('https://llm.ericlamm.xyz/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ prompt: chat }) + }) + const json = await res.json() + return json.response + } + + async *promptStream(chat: string): AsyncGenerator { + const res = await fetch('https://llm.ericlamm.xyz/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ prompt: chat, stream: true }) + }) + if (!res.body) throw new Error('Remote worker response body is not readable') + const reader = res.body.getReader() + const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) + while (true) { + const { done, value } = await reader.read() + if (done) break + const { response } = JSON.parse(decoder.decode(value, { stream: true })) + yield response + } + } + + async asSession(): Promise> { + console.warn('Remote worker session is not supported') + return { + ...this, + [Symbol.dispose]: () => { } + } + } + +} \ No newline at end of file diff --git a/src/types/extends/chrome-ai.d.ts b/src/types/extends/chrome-ai.d.ts new file mode 100644 index 00000000..35327564 --- /dev/null +++ b/src/types/extends/chrome-ai.d.ts @@ -0,0 +1,247 @@ +/** + * DEPRECATED! PLEASE USE https://www.npmjs.com/package/@types/dom-chromium-ai + */ + + +/** + * See https://github.com/explainers-by-googlers/writing-assistance-apis + * and https://github.com/explainers-by-googlers/prompt-api + */ +declare global { + interface AI { + assistant: AIAssistantFactory; + languageModel: AIAssistantFactory; // some old versions of chrome use this name + summarizer: AISummarizerFactory; + writer: AIWriterFactory; + rewriter: AIRewriterFactory; + } + + interface AIAssistantFactory { + create( options?: AIAssistantCreateOptions ): Promise< AIAssistant >; + capabilities(): Promise< AIAssistantCapabilities >; + } + + interface AIAssistantCapabilities { + available: AICapabilityAvailability; + defaultTopK?: number | null; + maxTopK?: number | null; + defaultTemperature?: + | 0.1 + | 0.2 + | 0.3 + | 0.4 + | 0.5 + | 0.6 + | 0.7 + | 0.8 + | 0.9 + | 1.0 + | null; + supportsLanguage( languageTag: string ): AICapabilityAvailability; + } + + interface AIAssistant { + prompt( + input: string, + options?: AIAssistantPromptOptions + ): Promise< string >; + promptStreaming( + input: string, + options?: AIAssistantPromptOptions + ): AIReadableStream; + countPromptTokens( + input: string, + options?: AIAssistantPromptOptions + ): number; + destroy(): void; + clone(): AIAssistant; + } + + interface AISummarizerFactory { + create( options?: AISummarizerCreateOptions ): Promise< AISummarizer >; + capabilities(): Promise< AISummarizerCapabilities >; + } + + interface AISummarizerCapabilities { + available: AICapabilityAvailability; + supportsType( tone: AISummarizerType ): AICapabilityAvailability; + supportsFormat( format: AISummarizerFormat ): AICapabilityAvailability; + supportsLength( length: AISummarizerLength ): AICapabilityAvailability; + supportsInputLanguage( languageTag: string ): AICapabilityAvailability; + } + + interface AISummarizerCreateOptions { + signal?: AbortSignal; + monitor?: AICreateMonitorCallback; + sharedContext?: string; + type?: AISummarizerType; // Default is 'key-points'. + format?: AISummarizerFormat; // Default is 'markdown'. + length?: AISummarizerLength; // Default is 'short'. + } + + type AISummarizerType = 'tl;dr' | 'key-points' | 'teaser' | 'headline'; + type AISummarizerFormat = 'plain-text' | 'markdown'; + type AISummarizerLength = 'short' | 'medium' | 'long'; + + interface AISummarizerSummarizeOptions { + signal?: AbortSignal; + context?: string; + } + + interface AISummarizer extends EventTarget { + ready: Promise< undefined >; + summarize( + input: string, + options?: AISummarizerSummarizeOptions + ): Promise< string >; + summarizeStreaming( + input: string, + options?: AISummarizerSummarizeOptions + ): AIReadableStream; + destroy(): void; + } + + interface AIWriterFactory { + create( options?: AIWriterCreateOptions ): Promise< AIWriter >; + capabilities(): Promise< AIWriterCapabilities >; + } + + interface AIWriterCapabilities { + available: AICapabilityAvailability; + supportsTone( tone: AIWriterTone ): AICapabilityAvailability; + supportsFormat( format: AIWriterFormat ): AICapabilityAvailability; + supportsLength( length: AIWriterLength ): AICapabilityAvailability; + supportsInputLanguage( languageTag: string ): AICapabilityAvailability; + } + + interface AIWriterCreateOptions { + signal?: AbortSignal; + monitor?: AICreateMonitorCallback; + sharedContext?: string; + tone?: AIWriterTone; // Default is 'key-points'. + format?: AIWriterFormat; // Default is 'markdown'. + length?: AIWriterLength; // Default is 'short'. + } + + // TODO: What about 'key-points'? File issue. + type AIWriterTone = 'formal' | 'neutral' | 'casual'; + type AIWriterFormat = 'plain-text' | 'markdown'; + type AIWriterLength = 'short' | 'medium' | 'long'; + + interface AIWriterWriteOptions { + signal?: AbortSignal; + context?: string; + } + + interface AIWriter { + write( + writingTask: string, + options?: AIWriterWriteOptions + ): Promise< string >; + writeStreaming( + writingTask: string, + options?: AIWriterWriteOptions + ): AIReadableStream; + tone: AIWriterTone; + format: AIWriterFormat; + length: AIWriterLength; + destroy(): void; + } + + interface AIRewriterFactory { + create( options?: AIRewriterCreateOptions ): Promise< AIRewriter >; + capabilities(): Promise< AIRewriterCapabilities >; + } + + interface AIRewriterCapabilities { + available: AICapabilityAvailability; + supportsTone( tone: AIRewriterTone ): AICapabilityAvailability; + supportsFormat( format: AIRewriterFormat ): AICapabilityAvailability; + supportsLength( length: AIRewriterLength ): AICapabilityAvailability; + supportsInputLanguage( languageTag: string ): AICapabilityAvailability; + } + + interface AIRewriterCreateOptions { + signal?: AbortSignal; + monitor?: AICreateMonitorCallback; + sharedContext?: string; + tone?: AIRewriterTone; // Default is 'as-is'. + format?: AIRewriterFormat; // Default is 'as-is'. + length?: AIRewriterLength; // Default is 'as-is'. + } + + type AIRewriterTone = 'as-is' | 'more-formal' | 'more-casual'; + type AIRewriterFormat = 'as-is' | 'plain-text' | 'markdown'; + type AIRewriterLength = 'as-is' | 'shorter' | 'longer'; + + interface AIRewriterRewriteOptions { + signal?: AbortSignal; + context?: string; + } + + interface AIRewriter { + rewrite( + writingTask: string, + options?: AIRewriterRewriteOptions + ): Promise< string >; + rewriteStreaming( + writingTask: string, + options?: AIRewriterRewriteOptions + ): AIReadableStream; + tone: AIRewriterTone; + format: AIRewriterFormat; + length: AIRewriterLength; + destroy(): void; + } + + type AICapabilityAvailability = 'readily' | 'after-download' | 'no'; + + interface InitialPrompt { + role: string; + content: string; + } + + interface AssistantPrompt extends InitialPrompt { + role: 'assistant'; + content: string; + } + + interface UserPrompt extends InitialPrompt { + role: 'user'; + content: string; + } + + interface SystemPrompt extends InitialPrompt { + role: 'system'; + } + + interface AIAssistantCreateOptions { + temperature?: 0.1 | 0.2 | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 | 1.0; + topK?: number; + systemPrompt?: string; + initialPrompts?: [ + SystemPrompt, + ...Array< UserPrompt | AssistantPrompt >, + ]; + signal?: AbortSignal; + monitor?: AICreateMonitorCallback; + } + + interface AIAssistantPromptOptions { + signal?: AbortSignal; + } + + interface AICreateMonitor extends EventTarget {} + + type AICreateMonitorCallback = ( monitor: AICreateMonitor ) => void; + + interface WindowOrWorkerGlobalScope { + ai: AI; + } + + interface AIReadableStream { + [Symbol.asyncIterator](): AsyncIterableIterator< string >; + } +} + +export type {}; \ No newline at end of file diff --git a/tests/modules/llm.js b/tests/modules/llm.js new file mode 100644 index 00000000..02fc8508 --- /dev/null +++ b/tests/modules/llm.js @@ -0,0 +1,7 @@ +import createLLMProvider from '~llms' + + +console.log('llm.js loaded!') +console.log(createLLMProvider) + +window.llms = { createLLMProvider } \ No newline at end of file diff --git a/tests/modules/utils.js b/tests/modules/utils.js index c593af7f..7dc9070d 100644 --- a/tests/modules/utils.js +++ b/tests/modules/utils.js @@ -2,4 +2,7 @@ import * as file from '~utils/file' import * as misc from '~utils/misc' import * as ffmpeg from '@ffmpeg/util' + +console.log('utils.js loaded') + window.utils = { file, ffmpeg, misc } \ No newline at end of file diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts new file mode 100644 index 00000000..97d0f5ef --- /dev/null +++ b/tests/units/llm.spec.ts @@ -0,0 +1,63 @@ +import { test, expect } from '@tests/fixtures/component' +import logger from '@tests/helpers/logger' + + +test('嘗試使用 Cloudflare AI 對話', async ({ page, modules }) => { + + test.skip(!process.env.CF_ACCOUNT_ID || !process.env.CF_API_TOKEN, '請設定 CF_ACCOUNT_ID 和 CF_API_TOKEN 環境變數') + + await modules['llm'].loadToPage() + await modules['utils'].loadToPage() + + const ret = await page.evaluate(async () => { + const { llms } = window as any + console.log('llms: ', llms) + const llm = await llms.createLLMProvider('cloudflare', + process.env.CF_ACCOUNT_ID, + process.env.CF_API_TOKEN + ) + return await llm.prompt('你好') + }) + + logger.info('response: ', ret) + await expect(ret).not.toBeEmpty() +}) + +test('嘗試使用 Gemini Nano 對話', async ({ page, modules }) => { + + const supported = await page.evaluate(async () => { + return !!window.ai; + }) + + test.skip(!supported, 'Gemini Nano 不支援此瀏覽器') + + await modules['llm'].loadToPage() + await modules['utils'].loadToPage() + + const ret = await page.evaluate(async () => { + const { llms } = window as any + console.log('llms: ', llms) + const llm = await llms.createLLMProvider('nano') + return await llm.prompt('你好') + }) + + logger.info('response: ', ret) + await expect(ret).not.toBeEmpty() +}) + +test('嘗試使用 Remote Worker 對話', async ({ page, modules }) => { + + await modules['llm'].loadToPage() + await modules['utils'].loadToPage() + + const ret = await page.evaluate(async () => { + const { llms } = window as any + console.log('llms: ', llms) + const llm = await llms.createLLMProvider('worker') + return await llm.prompt('你好') + }) + + logger.info('response: ', ret) + await expect(ret).not.toBeEmpty() + +}) \ No newline at end of file From 81c15ada79c26c01d9550ba2b38d5844e39f12ae Mon Sep 17 00:00:00 2001 From: eric2788 Date: Sat, 19 Oct 2024 21:28:40 +0800 Subject: [PATCH 02/15] finished llm providers --- src/api/cloudflare.ts | 43 ++++++++++++++++++++++++++ src/llms/cf-qwen.ts | 45 +++++++++++++++++++++++++++ src/llms/cloudflare-ai.ts | 59 ----------------------------------- src/llms/index.ts | 10 +++--- src/llms/remote-worker.ts | 2 ++ tests/fixtures/component.ts | 14 ++++++++- tests/units/llm.spec.ts | 61 +++++++++++++++++++++---------------- 7 files changed, 143 insertions(+), 91 deletions(-) create mode 100644 src/api/cloudflare.ts create mode 100644 src/llms/cf-qwen.ts delete mode 100644 src/llms/cloudflare-ai.ts diff --git a/src/api/cloudflare.ts b/src/api/cloudflare.ts new file mode 100644 index 00000000..9f352ef7 --- /dev/null +++ b/src/api/cloudflare.ts @@ -0,0 +1,43 @@ +import type { AIResponse, Result } from "~types/cloudflare"; + +const BASE_URL = 'https://api.cloudflare.com/client/v4' + +export async function runAI(data: any, { token, account, model }: { token: string, account: string, model: string }): Promise> { + const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ ...data, stream: false }) + }) + return await res.json() +} + +export async function *runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator { + const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ ...data, stream: true }) + }) + if (!res.body) throw new Error('Cloudflare AI response body is not readable') + const reader = res.body.getReader() + const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) + while (true) { + const { done, value } = await reader.read() + if (done) break + const { response } = JSON.parse(decoder.decode(value, { stream: true })) + yield response + } +} + +export async function validateAIToken(accountId: string, token: string): Promise { + const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?per_page=1`, { + headers: { + Authorization: `Bearer ${this.apiToken}` + } + }) + const data = await res.json() as Result + return data.success +} \ No newline at end of file diff --git a/src/llms/cf-qwen.ts b/src/llms/cf-qwen.ts new file mode 100644 index 00000000..c68420b1 --- /dev/null +++ b/src/llms/cf-qwen.ts @@ -0,0 +1,45 @@ +import { runAI, runAIStream, validateAIToken } from "~api/cloudflare"; +import type { LLMProviders, Session } from "~llms"; + +export default class CloudFlareQwen implements LLMProviders { + + private static readonly MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq' + + constructor( + private readonly accountId: string, + private readonly apiToken: string, + ) { } + + async validate(): Promise { + const success = await validateAIToken(this.accountId, this.apiToken) + if (!success) throw new Error('Cloudflare API 验证失败') + } + + async prompt(chat: string): Promise { + const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL }) + if (!res.result) throw new Error(res.errors.join(', ')) + return res.result.response + } + + async *promptStream(chat: string): AsyncGenerator { + return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL }) + } + + async asSession(): Promise> { + console.warn('Cloudflare AI session is not supported') + return { + ...this, + [Symbol.dispose]: () => { } + } + } + + private wrap(chat: string): any { + return { + max_tokens: 512, + prompt: chat, + temperature: 0.2 + } + } + + +} \ No newline at end of file diff --git a/src/llms/cloudflare-ai.ts b/src/llms/cloudflare-ai.ts deleted file mode 100644 index f3f2f477..00000000 --- a/src/llms/cloudflare-ai.ts +++ /dev/null @@ -1,59 +0,0 @@ -import type { LLMProviders, Session } from "~llms"; - -export default class CloudflareAI implements LLMProviders { - - constructor( - private readonly accountId: string, - private readonly apiToken: String, - private readonly model = '@cf/facebook/bart-large-cnn' // text summarization model - ) { } - - async validate(): Promise { - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/models/search?per_page=1`, { - headers: { - Authorization: `Bearer ${this.apiToken}` - } - }) - const json = await res.json() - if (!json.success) throw new Error('Cloudflare API 验证失败') - } - - async prompt(chat: string): Promise { - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { - headers: { - Authorization: `Bearer ${this.apiToken}` - }, - body: JSON.stringify({ prompt: chat }) - }) - const json = await res.json() - return json.response - } - - async *promptStream(chat: string): AsyncGenerator { - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { - headers: { - Authorization: `Bearer ${this.apiToken}` - }, - body: JSON.stringify({ prompt: chat, stream: true }) - }) - if (!res.body) throw new Error('Cloudflare AI response body is not readable') - const reader = res.body.getReader() - const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) - while (true) { - const { done, value } = await reader.read() - if (done) break - const { response } = JSON.parse(decoder.decode(value, { stream: true })) - yield response - } - } - - async asSession(): Promise> { - console.warn('Cloudflare AI session is not supported') - return { - ...this, - [Symbol.dispose]: () => { } - } - } - - -} \ No newline at end of file diff --git a/src/llms/index.ts b/src/llms/index.ts index 92373757..92f76417 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -1,4 +1,4 @@ -import cloudflare from './cloudflare-ai' +import qwen from './cf-qwen' import nano from './gemini-nano' import worker from './remote-worker' @@ -12,14 +12,16 @@ export interface LLMProviders { export type Session = Disposable & Omit const llms = { - cloudflare, + qwen, nano, worker } -export type LLMTypes = keyof typeof llms +export type LLMs = typeof llms -export async function createLLMProvider(type: LLMTypes, ...args: any[]): Promise { +export type LLMTypes = keyof LLMs + +async function createLLMProvider(type: K, ...args: ConstructorParameters): Promise { const LLM = llms[type].bind(this, ...args) return new LLM() } diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts index 42a4a8ba..2d399c70 100644 --- a/src/llms/remote-worker.ts +++ b/src/llms/remote-worker.ts @@ -20,6 +20,7 @@ export default class RemoteWorker implements LLMProviders { }, body: JSON.stringify({ prompt: chat }) }) + if (!res.ok) throw new Error(await res.text()) const json = await res.json() return json.response } @@ -32,6 +33,7 @@ export default class RemoteWorker implements LLMProviders { }, body: JSON.stringify({ prompt: chat, stream: true }) }) + if (!res.ok) throw new Error(await res.text()) if (!res.body) throw new Error('Remote worker response body is not readable') const reader = res.body.getReader() const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) diff --git a/tests/fixtures/component.ts b/tests/fixtures/component.ts index a679591d..76ce3e01 100644 --- a/tests/fixtures/component.ts +++ b/tests/fixtures/component.ts @@ -6,6 +6,7 @@ import type { StreamUrls } from "~background/messages/get-stream-urls" import { Strategy } from "@tests/utils/misc" import type { LiveRoomInfo } from "@tests/helpers/bilibili-api" import fs from 'fs/promises' +import { chromium } from "@playwright/test" export type IntegrationFixtures = { modules: Record @@ -72,7 +73,18 @@ export const test = base.extend({ } test.skip(!selected || !stream || stream.length === 0, '无法获取直播流') await use({ stream, ...selected }) - } + }, + + // context: async ({ }, use) => { + // const context = await chromium.launchPersistentContext('', { + // headless: false, + // args: [ + // `--disable-web-security`, + // ], + // }); + // await use(context); + // await context.close(); + // }, }) export const expect = test.expect diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts index 97d0f5ef..aebe3413 100644 --- a/tests/units/llm.spec.ts +++ b/tests/units/llm.spec.ts @@ -1,29 +1,32 @@ import { test, expect } from '@tests/fixtures/component' import logger from '@tests/helpers/logger' +import createLLMProvider from "~llms" -test('嘗試使用 Cloudflare AI 對話', async ({ page, modules }) => { +test('嘗試使用 Cloudflare AI 對話', { tag: "@scoped" }, async () => { test.skip(!process.env.CF_ACCOUNT_ID || !process.env.CF_API_TOKEN, '請設定 CF_ACCOUNT_ID 和 CF_API_TOKEN 環境變數') - await modules['llm'].loadToPage() - await modules['utils'].loadToPage() + // await modules['llm'].loadToPage() + // await modules['utils'].loadToPage() - const ret = await page.evaluate(async () => { - const { llms } = window as any - console.log('llms: ', llms) - const llm = await llms.createLLMProvider('cloudflare', - process.env.CF_ACCOUNT_ID, - process.env.CF_API_TOKEN - ) - return await llm.prompt('你好') - }) + // const res = await page.evaluate(async ({ accountId, apiToken }) => { + // const { llms } = window as any + // console.log('llms: ', llms) + // const llm = await llms.createLLMProvider('cloudflare', accountId, apiToken) + // return await llm.prompt('你好') + // }, { accountId: process.env.CF_ACCOUNT_ID, apiToken: process.env.CF_API_TOKEN }) + + const llm = await createLLMProvider('qwen', process.env.CF_ACCOUNT_ID, process.env.CF_API_TOKEN) + const res = await llm.prompt('你好') + + logger.info('response: ', res) + expect(res).not.toBeUndefined() + expect(res).not.toBe('') - logger.info('response: ', ret) - await expect(ret).not.toBeEmpty() }) -test('嘗試使用 Gemini Nano 對話', async ({ page, modules }) => { +test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modules }) => { const supported = await page.evaluate(async () => { return !!window.ai; @@ -45,19 +48,23 @@ test('嘗試使用 Gemini Nano 對話', async ({ page, modules }) => { await expect(ret).not.toBeEmpty() }) -test('嘗試使用 Remote Worker 對話', async ({ page, modules }) => { +test('嘗試使用 Remote Worker 對話', { tag: "@scoped" }, async () => { - await modules['llm'].loadToPage() - await modules['utils'].loadToPage() + // await modules['llm'].loadToPage() + // await modules['utils'].loadToPage() - const ret = await page.evaluate(async () => { - const { llms } = window as any - console.log('llms: ', llms) - const llm = await llms.createLLMProvider('worker') - return await llm.prompt('你好') - }) + // const res = await page.evaluate(async () => { + // const { llms } = window as any + // console.log('llms: ', llms) + // const llm = await llms.createLLMProvider('worker') + // return await llm.prompt('你好') + // }) + + const llm = await createLLMProvider('worker') + const res = await llm.prompt('你好') + + logger.info('response: ', res) + expect(res).not.toBeUndefined() + expect(res).not.toBe('') - logger.info('response: ', ret) - await expect(ret).not.toBeEmpty() - }) \ No newline at end of file From 99174e9242f198572ab14c0da679f8329720a16f Mon Sep 17 00:00:00 2001 From: eric2788 Date: Sat, 19 Oct 2024 00:29:36 +0800 Subject: [PATCH 03/15] added test cases for ai feature --- src/components/ChatBubble.tsx | 45 +++++++++++++++ src/llms/gemini-nano.ts | 16 ++++-- src/types/cloudflare/index.ts | 8 +++ src/types/cloudflare/workers-ai.ts | 5 ++ tests/features/jimaku.spec.ts | 92 +++++++++++++++++++++++++++++- tests/modules/llm.js | 4 -- tests/units/llm.spec.ts | 48 +++++++++------- 7 files changed, 189 insertions(+), 29 deletions(-) create mode 100644 src/components/ChatBubble.tsx create mode 100644 src/types/cloudflare/index.ts create mode 100644 src/types/cloudflare/workers-ai.ts diff --git a/src/components/ChatBubble.tsx b/src/components/ChatBubble.tsx new file mode 100644 index 00000000..fc57bb86 --- /dev/null +++ b/src/components/ChatBubble.tsx @@ -0,0 +1,45 @@ +import { Avatar } from "@material-tailwind/react" +import type { ReactNode } from "react" + + + +export type ChatBubbleProps = { + avatar: string + name: string + messages: Chat[] + loading?: boolean + footer?: ReactNode +} + +export type Chat = { + text: ReactNode + time?: string +} + + +function ChatBubble(props: ChatBubbleProps): JSX.Element { + const { avatar, name, messages, loading, footer } = props + return ( +
+ +
+
{name}
+ {messages.map((message, index) => ( +
+
+
{message.text}
+
+ {message.time && ( +
+
{message.time}
+
+ )} +
+ ))} + {footer} +
+
+ ) +} + +export default ChatBubble \ No newline at end of file diff --git a/src/llms/gemini-nano.ts b/src/llms/gemini-nano.ts index 7ab3fd4a..8d2e79a0 100644 --- a/src/llms/gemini-nano.ts +++ b/src/llms/gemini-nano.ts @@ -11,13 +11,21 @@ export default class GeminiNano implements LLMProviders { } async prompt(chat: string): Promise { - using session = await this.asSession() - return session.prompt(chat) + const session = await this.asSession() + try { + return session.prompt(chat) + } finally { + session[Symbol.dispose]() + } } async *promptStream(chat: string): AsyncGenerator { - using session = await this.asSession() - return session.promptStream(chat) + const session = await this.asSession() + try { + return session.promptStream(chat) + } finally { + session[Symbol.dispose]() + } } async asSession(): Promise> { diff --git a/src/types/cloudflare/index.ts b/src/types/cloudflare/index.ts new file mode 100644 index 00000000..34351d67 --- /dev/null +++ b/src/types/cloudflare/index.ts @@ -0,0 +1,8 @@ +export * from './workers-ai' + +export type Result = { + success: boolean + result: T + errors: string[] + messages: string[] +} \ No newline at end of file diff --git a/src/types/cloudflare/workers-ai.ts b/src/types/cloudflare/workers-ai.ts new file mode 100644 index 00000000..58d10eb6 --- /dev/null +++ b/src/types/cloudflare/workers-ai.ts @@ -0,0 +1,5 @@ + + +export type AIResponse = { + response: string +} \ No newline at end of file diff --git a/tests/features/jimaku.spec.ts b/tests/features/jimaku.spec.ts index 357ff676..ef1b8060 100644 --- a/tests/features/jimaku.spec.ts +++ b/tests/features/jimaku.spec.ts @@ -125,6 +125,96 @@ test('測試彈出同傳視窗', async ({ room, context, optionPageUrl, page, co await expect(checkbox).not.toBeChecked() }) +test('测试同传字幕AI总结', { tag: "@scoped" }, async ({ room, content: p, context, optionPageUrl, page }) => { + + test.slow() + logger.info('正在修改設定...') + const settingsPage = await context.newPage() + await settingsPage.bringToFront() + await settingsPage.goto(optionPageUrl, { waitUntil: 'domcontentloaded' }) + await settingsPage.waitForTimeout(1000) + + await settingsPage.getByText('功能设定').click() + await settingsPage.getByText('AI 设定').click() + await settingsPage.getByText('启用同传字幕AI总结').click() + await settingsPage.getByText('保存设定').click() + await settingsPage.waitForTimeout(2000) + + logger.info('正在測試AI总结...') + await page.bringToFront() + const buttonList = await getButtonList(p) + expect(buttonList.length).toBe(3) + await expect(buttonList[2]).toHaveText('同传字幕AI总结') + + + await p.locator('#subtitle-list').waitFor({ state: 'visible' }) + const conversations = [ + '大家好', + '早上好', + '知道我今天吃了什么吗?', + '是麦当劳哦!', + '"不就个麦当劳而已吗"不是啦', + '是最近那个很热门的新品', + '对,就是那个', + '然后呢, 今天久违的出门了', + '对,平时都是宅在家里的呢', + '"终于长大了"喂w', + '然后今天去了漫展来着', + '很多人呢', + '之前的我看到那么多人肯定社恐了', + '但今次意外的没有呢', + '"果然是长大了"也是呢', + '然后呢, 今天买了很多东西', + '插画啊,手办啊,周边之类的', + '荷包大出血w', + '不过觉得花上去应该值得的...吧?', + '喂,好过分啊', + '不过确实不应该花那么多钱的', + '然后呢,回家途中看到了蟑螂的尸体', + '太恶心了', + '然后把我一整天好心情搞没了w', + '"就因为一个蟑螂"对www', + '不过跟你们谈完反而心情好多了', + '谢谢大家', + '那么今天的杂谈就到这里吧', + '下次再见啦', + '拜拜~' + ] + + for (const danmaku of conversations.map(t => `主播:${t}`)) { + await room.sendDanmaku(`【${danmaku}】`) + } + await p.waitForTimeout(3000) + + let subtitleList = await p.locator('#subtitle-list > p').filter({ hasText: '主播:' }).all() + expect(subtitleList.length).toBe(conversations.length) + + const newWindow = context.waitForEvent('page', { predicate: p => p.url().includes('summarizer.html') }) + await buttonList[2].click() + const summarizer = await newWindow + await summarizer.bringToFront() + const loader = summarizer.getByText('正在加载同传字幕总结') + await expect(loader).toBeVisible() + await summarizer.waitForTimeout(3000) + + await expect(summarizer.getByText('错误')).toBeHidden({ timeout: 5000 }) + await expect(loader).toBeHidden({ timeout: 30000 }) + + logger.info('正在測試AI总結結果... (15s)') + await summarizer.waitForTimeout(15000) + await expect(summarizer.getByText('错误')).toBeHidden({ timeout: 5000 }) + const res = await summarizer.getByTestId('同传字幕总结-bubble-chat-0').locator('h5.leading-snug').textContent() + logger.debug('AI Summary:', res) + + const maybe = expect.configure({ soft: true }) + maybe(res).toMatch(/主播|日本VTuber/) + maybe(res).toMatch(/直播|观众/) + maybe(res).toContain('麦当劳') + maybe(res).toContain('漫展') + maybe(res).toContain('蟑螂') + +}) + test('測試離線記錄彈幕', async ({ room, content: p, context, optionPageUrl, page }) => { @@ -340,7 +430,7 @@ test('測試保存設定後 css 能否生效', async ({ context, content, option settingsPage.getByTestId('jimaku-position'), '置左' ) - + await settingsPage.getByTestId('jimaku-color').fill('#123456') diff --git a/tests/modules/llm.js b/tests/modules/llm.js index 02fc8508..59629261 100644 --- a/tests/modules/llm.js +++ b/tests/modules/llm.js @@ -1,7 +1,3 @@ import createLLMProvider from '~llms' - -console.log('llm.js loaded!') -console.log(createLLMProvider) - window.llms = { createLLMProvider } \ No newline at end of file diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts index aebe3413..6fd13cc9 100644 --- a/tests/units/llm.spec.ts +++ b/tests/units/llm.spec.ts @@ -7,23 +7,26 @@ test('嘗試使用 Cloudflare AI 對話', { tag: "@scoped" }, async () => { test.skip(!process.env.CF_ACCOUNT_ID || !process.env.CF_API_TOKEN, '請設定 CF_ACCOUNT_ID 和 CF_API_TOKEN 環境變數') - // await modules['llm'].loadToPage() - // await modules['utils'].loadToPage() + const llm = createLLMProvider('qwen', process.env.CF_ACCOUNT_ID, process.env.CF_API_TOKEN) - // const res = await page.evaluate(async ({ accountId, apiToken }) => { - // const { llms } = window as any - // console.log('llms: ', llms) - // const llm = await llms.createLLMProvider('cloudflare', accountId, apiToken) - // return await llm.prompt('你好') - // }, { accountId: process.env.CF_ACCOUNT_ID, apiToken: process.env.CF_API_TOKEN }) - - const llm = await createLLMProvider('qwen', process.env.CF_ACCOUNT_ID, process.env.CF_API_TOKEN) + logger.info('正在测试 json 返回请求...') const res = await llm.prompt('你好') logger.info('response: ', res) expect(res).not.toBeUndefined() expect(res).not.toBe('') + logger.info('正在测试 SSE 请求...') + const res2 = llm.promptStream('地球为什么是圆的?') + + let msg = ''; + for await (const r of res2) { + logger.info('response: ', r) + msg += r + } + + expect(msg).not.toBeUndefined() + expect(msg).not.toBe('') }) test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modules }) => { @@ -32,6 +35,7 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul return !!window.ai; }) + logger.debug('Gemini Nano supported: ', supported) test.skip(!supported, 'Gemini Nano 不支援此瀏覽器') await modules['llm'].loadToPage() @@ -50,21 +54,25 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul test('嘗試使用 Remote Worker 對話', { tag: "@scoped" }, async () => { - // await modules['llm'].loadToPage() - // await modules['utils'].loadToPage() - - // const res = await page.evaluate(async () => { - // const { llms } = window as any - // console.log('llms: ', llms) - // const llm = await llms.createLLMProvider('worker') - // return await llm.prompt('你好') - // }) + const llm = createLLMProvider('worker') - const llm = await createLLMProvider('worker') + logger.info('正在测试 json 返回请求...') const res = await llm.prompt('你好') logger.info('response: ', res) expect(res).not.toBeUndefined() expect(res).not.toBe('') + logger.info('正在测试 SSE 请求...') + const res2 = llm.promptStream('地球为什么是圆的?') + + let msg = ''; + for await (const r of res2) { + logger.info('response: ', r) + msg += r + } + + expect(msg).not.toBeUndefined() + expect(msg).not.toBe('') + }) \ No newline at end of file From acb3b0f085e8e3febcc47b8b36d0c6d62eaf03f8 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Sat, 19 Oct 2024 22:58:16 +0800 Subject: [PATCH 04/15] finished ai settings page --- src/options/components/Hints.tsx | 1 - src/options/components/Selector.tsx | 3 +- .../features/jimaku/components/AIFragment.tsx | 136 ++++++++++++++++++ src/options/features/jimaku/index.tsx | 10 +- 4 files changed, 147 insertions(+), 3 deletions(-) create mode 100644 src/options/features/jimaku/components/AIFragment.tsx diff --git a/src/options/components/Hints.tsx b/src/options/components/Hints.tsx index 7cd1a6f0..5a2d7197 100644 --- a/src/options/components/Hints.tsx +++ b/src/options/components/Hints.tsx @@ -16,7 +16,6 @@ function Hints(props: HintsProps): JSX.Element { = { label: string disabled?: boolean options: SelectorOption[] + className?: string } @@ -32,7 +33,7 @@ function Selector(props: SelectorProps): JSX.Element { } return ( -
+
!props.disabled && setOpen(!isOpen)}>
diff --git a/src/options/features/jimaku/components/AIFragment.tsx b/src/options/features/jimaku/components/AIFragment.tsx new file mode 100644 index 00000000..dd403536 --- /dev/null +++ b/src/options/features/jimaku/components/AIFragment.tsx @@ -0,0 +1,136 @@ +import { Button, Input, List, Tooltip, Typography } from "@material-tailwind/react" +import { type ChangeEvent, Fragment, useState } from "react" +import { toast } from "sonner/dist" +import type { StateProxy } from "~hooks/binding" +import type { LLMProviders, LLMTypes } from "~llms" +import createLLMProvider from "~llms" +import ExperienmentFeatureIcon from "~options/components/ExperientmentFeatureIcon" +import Selector from "~options/components/Selector" +import SwitchListItem from "~options/components/SwitchListItem" + + + +export type AISchema = { + enabled: boolean + provider: LLMTypes + + // cloudflare settings + accountId?: string + apiToken?: string +} + + +export const aiDefaultSettings: Readonly = { + enabled: false, + provider: 'worker' +} + + +function AIFragment({ state, useHandler }: StateProxy): JSX.Element { + + const [validating, setValidating] = useState(false) + + const handler = useHandler, string>((e) => e.target.value) + const checker = useHandler, boolean>((e) => e.target.checked) + + const onValidate = async () => { + setValidating(true) + try { + let provider: LLMProviders; + if (state.provider === 'qwen') { + provider = await createLLMProvider(state.provider, state.accountId, state.apiToken) + } else { + provider = await createLLMProvider(state.provider) + } + await provider.validate() + toast.success('配置可用!') + } catch (e) { + toast.error('配置不可用: ' + e.message) + } finally { + setValidating(false) + } + } + + return ( + + + } + /> + + {state.enabled && ( + + + className="col-span-2" + data-testid="ai-provider" + label="AI 提供商" + value={state.provider} + onChange={e => state.provider = e} + options={[ + { label: 'Cloudflare AI', value: 'qwen' }, + { label: '有限度服务器', value: 'worker' }, + { label: 'Chrome 浏览器内置 AI', value: 'nano' } + ]} + /> + {state.provider === 'qwen' && ( + + + + + + 点击此处 + 查看如何获得 Cloudflare API Token 和 Account ID + + + + + )} + + )} +
+ +
+
+ ) +} + +export default AIFragment \ No newline at end of file diff --git a/src/options/features/jimaku/index.tsx b/src/options/features/jimaku/index.tsx index fc3175c9..778b9a2e 100644 --- a/src/options/features/jimaku/index.tsx +++ b/src/options/features/jimaku/index.tsx @@ -10,6 +10,7 @@ import ButtonFragment, { buttonDefaultSettings, type ButtonSchema } from "./comp import DanmakuZone, { danmakuDefaultSettings, type DanmakuSchema } from "./components/DanmakuFragment" import JimakuZone, { jimakuDefaultSettings, type JimakuSchema } from "./components/JimakuFragment" import ListingFragment, { listingDefaultSettings, type ListingSchema } from "./components/ListingFragment" +import AIFragment, { aiDefaultSettings, type AISchema } from "./components/AIFragment" export const title: string = '同传弹幕过滤' @@ -28,6 +29,7 @@ export type FeatureSettingSchema = { danmakuZone: DanmakuSchema, buttonZone: ButtonSchema, listingZone: ListingSchema + aiZone: AISchema } export const defaultSettings: Readonly = { @@ -39,7 +41,8 @@ export const defaultSettings: Readonly = { jimakuZone: jimakuDefaultSettings, danmakuZone: danmakuDefaultSettings, buttonZone: buttonDefaultSettings, - listingZone: listingDefaultSettings + listingZone: listingDefaultSettings, + aiZone: aiDefaultSettings } const zones: { @@ -66,6 +69,11 @@ const zones: { Zone: ListingFragment, title: '同传名单设定', key: 'listingZone' + }, + { + Zone: AIFragment, + title: 'AI 设定', + key: 'aiZone' } ] From fd57ba5a2fe70c61eb015e6f6935a5ecec5c5044 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Sun, 20 Oct 2024 22:30:58 +0800 Subject: [PATCH 05/15] finished frontend and extension page --- package.json | 3 +- src/api/cloudflare.ts | 15 +-- src/background/forwards.ts | 4 +- src/background/forwards/summerize.ts | 8 ++ src/background/messages/open-tab.ts | 18 ++-- src/features/jimaku/components/ButtonArea.tsx | 21 ++++- src/llms/index.ts | 2 +- src/llms/remote-worker.ts | 7 +- .../features/jimaku/components/AIFragment.tsx | 6 +- src/tabs/summarizer.tsx | 94 +++++++++++++++++++ src/utils/binary.ts | 22 +++++ src/utils/fetch.ts | 2 +- 12 files changed, 177 insertions(+), 25 deletions(-) create mode 100644 src/background/forwards/summerize.ts create mode 100644 src/tabs/summarizer.tsx diff --git a/package.json b/package.json index eca8d926..3fe907c8 100644 --- a/package.json +++ b/package.json @@ -82,7 +82,8 @@ "*://api.live.bilibili.com/*", "*://live.bilibili.com/*", "*://*.bilivideo.com/*", - "*://*.ericlamm.xyz/*" + "*://*.ericlamm.xyz/*", + "*://*.cloudflare.com/*" ], "permissions": [ "notifications", diff --git a/src/api/cloudflare.ts b/src/api/cloudflare.ts index 9f352ef7..9cba0130 100644 --- a/src/api/cloudflare.ts +++ b/src/api/cloudflare.ts @@ -1,4 +1,5 @@ import type { AIResponse, Result } from "~types/cloudflare"; +import { parseSSEResponses } from "~utils/binary"; const BASE_URL = 'https://api.cloudflare.com/client/v4' @@ -10,7 +11,9 @@ export async function runAI(data: any, { token, account, model }: { token: strin }, body: JSON.stringify({ ...data, stream: false }) }) - return await res.json() + const json = await res.json() as Result + if (!res.ok) throw new Error(json.errors.join('\n')) + return json } export async function *runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator { @@ -21,13 +24,13 @@ export async function *runAIStream(data: any, { token, account, model }: { token }, body: JSON.stringify({ ...data, stream: true }) }) + if (!res.ok) { + const json = await res.json() as Result + throw new Error(json.errors.join('\n')) + } if (!res.body) throw new Error('Cloudflare AI response body is not readable') const reader = res.body.getReader() - const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) - while (true) { - const { done, value } = await reader.read() - if (done) break - const { response } = JSON.parse(decoder.decode(value, { stream: true })) + for await (const response of parseSSEResponses(reader, '[DONE]')) { yield response } } diff --git a/src/background/forwards.ts b/src/background/forwards.ts index 07099e7f..a0d57803 100644 --- a/src/background/forwards.ts +++ b/src/background/forwards.ts @@ -4,6 +4,7 @@ import * as danmaku from './forwards/danmaku' import * as jimaku from './forwards/jimaku' import * as redirect from './forwards/redirect' import * as streamContent from './forwards/stream-content' +import * as jimakuSummarize from './forwards/summerize' export type ForwardData = typeof forwards @@ -153,5 +154,6 @@ const forwards = { 'redirect': redirect, 'danmaku': danmaku, 'blive-data': bliveData, - 'stream-content': streamContent + 'stream-content': streamContent, + 'jimaku-summarize': jimakuSummarize } diff --git a/src/background/forwards/summerize.ts b/src/background/forwards/summerize.ts new file mode 100644 index 00000000..1ee9248a --- /dev/null +++ b/src/background/forwards/summerize.ts @@ -0,0 +1,8 @@ +import { useDefaultHandler } from "~background/forwards" + +export type ForwardBody = { + roomId: string + jimakus: string[] +} + +export default useDefaultHandler() \ No newline at end of file diff --git a/src/background/messages/open-tab.ts b/src/background/messages/open-tab.ts index 8c74473a..36d06fb4 100644 --- a/src/background/messages/open-tab.ts +++ b/src/background/messages/open-tab.ts @@ -5,17 +5,23 @@ export type RequestBody = { tab?: string active?: boolean params?: Record + singleton?: boolean } const handler: PlasmoMessaging.MessageHandler = async (req, res) => { const { url, tab, active } = req.body const queryString = req.body.params ? `?${new URLSearchParams(req.body.params).toString()}` : '' - const result = await chrome.tabs.create({ - url: tab ? - chrome.runtime.getURL(`/tabs/${tab}.html${queryString}`) : - url + queryString, - active - }) + const fullUrl = tab ? chrome.runtime.getURL(`/tabs/${tab}.html${queryString}`) : url + queryString + if (req.body.singleton) { + const tabs = await chrome.tabs.query({ url: fullUrl }) + if (tabs.length) { + const tab = tabs[0] + await chrome.tabs.update(tab.id, { active: true }) + res.send(tab) + return + } + } + const result = await chrome.tabs.create({ url: fullUrl, active }) res.send(result) } diff --git a/src/features/jimaku/components/ButtonArea.tsx b/src/features/jimaku/components/ButtonArea.tsx index b5d6169b..f677e2a9 100644 --- a/src/features/jimaku/components/ButtonArea.tsx +++ b/src/features/jimaku/components/ButtonArea.tsx @@ -8,6 +8,10 @@ import type { Jimaku } from "./JimakuLine"; import { createPortal } from "react-dom"; import ButtonSwitchList from "./ButtonSwitchList"; import TailwindScope from "~components/TailwindScope"; +import { toast } from "sonner/dist"; +import { sendMessager } from "~utils/messaging"; +import { sendForward } from "~background/forwards"; +import { sleep } from "~utils/misc"; export type ButtonAreaProps = { clearJimaku: VoidFunction @@ -17,7 +21,7 @@ export type ButtonAreaProps = { function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element { const { settings, info } = useContext(ContentContext) - const { jimakuZone, buttonZone: btnStyle, jimakuPopupWindow } = useContext(JimakuFeatureContext) + const { jimakuZone, buttonZone: btnStyle, jimakuPopupWindow, aiZone } = useContext(JimakuFeatureContext) const { order } = jimakuZone const { enabledRecording } = settings["settings.features"] @@ -36,6 +40,16 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element { const [show, setShow] = useState(!info.isTheme) + const summerize = async () => { + if (jimakus.length < 10) { + toast.warning('至少需要有10条同传字幕才可总结。') + return + } + await sendMessager('open-tab', { tab: 'summarizer', params: { roomId: info.room, title: info.title }, active: true, singleton: true }) + await sleep(2000) + sendForward('pages', 'jimaku-summarize', { roomId: info.room, jimakus: jimakus.map(j => j.text) }) + } + return ( {show && ( @@ -59,6 +73,11 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element { 弹出同传视窗 } + {aiZone.enabled && ( + + 同传字幕AI总结 + + )}
)} {info.isTheme && document.querySelector(upperHeaderArea) !== null && createPortal( diff --git a/src/llms/index.ts b/src/llms/index.ts index 92f76417..52675e43 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -21,7 +21,7 @@ export type LLMs = typeof llms export type LLMTypes = keyof LLMs -async function createLLMProvider(type: K, ...args: ConstructorParameters): Promise { +function createLLMProvider(type: K, ...args: ConstructorParameters): LLMProviders { const LLM = llms[type].bind(this, ...args) return new LLM() } diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts index 2d399c70..9293db38 100644 --- a/src/llms/remote-worker.ts +++ b/src/llms/remote-worker.ts @@ -1,4 +1,5 @@ import type { LLMProviders, Session } from "~llms"; +import { parseSSEResponses } from "~utils/binary"; // for my worker, so limited usage @@ -36,11 +37,7 @@ export default class RemoteWorker implements LLMProviders { if (!res.ok) throw new Error(await res.text()) if (!res.body) throw new Error('Remote worker response body is not readable') const reader = res.body.getReader() - const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) - while (true) { - const { done, value } = await reader.read() - if (done) break - const { response } = JSON.parse(decoder.decode(value, { stream: true })) + for await (const response of parseSSEResponses(reader, '[DONE]')) { yield response } } diff --git a/src/options/features/jimaku/components/AIFragment.tsx b/src/options/features/jimaku/components/AIFragment.tsx index dd403536..ec0eda64 100644 --- a/src/options/features/jimaku/components/AIFragment.tsx +++ b/src/options/features/jimaku/components/AIFragment.tsx @@ -38,9 +38,9 @@ function AIFragment({ state, useHandler }: StateProxy): JSX.Element { try { let provider: LLMProviders; if (state.provider === 'qwen') { - provider = await createLLMProvider(state.provider, state.accountId, state.apiToken) + provider = createLLMProvider(state.provider, state.accountId, state.apiToken) } else { - provider = await createLLMProvider(state.provider) + provider = createLLMProvider(state.provider) } await provider.validate() toast.success('配置可用!') @@ -68,7 +68,7 @@ function AIFragment({ state, useHandler }: StateProxy): JSX.Element { className="col-span-2" data-testid="ai-provider" - label="AI 提供商" + label="技术来源" value={state.provider} onChange={e => state.provider = e} options={[ diff --git a/src/tabs/summarizer.tsx b/src/tabs/summarizer.tsx new file mode 100644 index 00000000..fe9c1340 --- /dev/null +++ b/src/tabs/summarizer.tsx @@ -0,0 +1,94 @@ +import { Typography } from "@material-tailwind/react"; +import icon from 'raw:assets/icon.png'; +import { useCallback, useDeferredValue, useEffect, useState } from "react"; +import { useForwarder } from "~hooks/forwarder"; + +import ChatBubble from "~components/ChatBubble"; +import createLLMProvider, { type LLMProviders } from "~llms"; +import type { AISchema } from "~options/features/jimaku/components/AIFragment"; +import '~style.css'; +import { getSettingStorage } from "~utils/storage"; + +const urlParams = new URLSearchParams(window.location.search); +const roomId = urlParams.get('roomId') +const roomTitle = urlParams.get('title') + +function createLLM(schema: AISchema): LLMProviders { + switch (schema.provider) { + case 'worker': + case 'nano': + return createLLMProvider(schema.provider) + case 'qwen': + return createLLMProvider(schema.provider, schema.accountId, schema.apiToken) + } +} + +const loadingText = '正在加载同传字幕总结.....' + +function App() { + + const [title, setTitle] = useState('加载中') + const [loading, setLoading] = useState(true) + const [summary, setSummary] = useState(loadingText) + const [error, setError] = useState('') + const deferredSummary = useDeferredValue(summary) + const forwarder = useForwarder('jimaku-summarize', 'pages') + + useEffect(() => { + if (!roomId) { + alert('未指定房间号') + return + } + setTitle(roomTitle ?? `B站直播间 ${roomId}`) + const remover = forwarder.addHandler((data) => { + if (data.roomId !== roomId) return + console.debug('received ', data.jimakus.length, 'danmakus') + summarize(data.jimakus) + remover() + }) + return remover + }, []) + + const summarize = useCallback(async (danmakus: string[]) => { + try { + if (danmakus.length < 10) { + throw new Error('至少需要有10条同传字幕才可总结。') + } + const { jimaku: { aiZone } } = await getSettingStorage('settings.features') + const llm = createLLM(aiZone) + const summaryStream = llm.promptStream(`这位是一名在b站直播间直播的日本vtuber说过的话,请根据下文对话猜测与观众的互动内容,并用中文总结一下他们的对话:\n\n${danmakus.join('\n')}`) + setLoading(false) + for await (const words of summaryStream) { + setSummary(summary => summary === loadingText ? words : summary + words) + } + } catch (err) { + setLoading(false) + console.error(err) + setError('未知错误: ' + err.message) + } + }, []) + + return ( +
+
+ {title} +
+
+
+
+ {error}} + /> +
+
+
+
+ ) +} + + +export default App; \ No newline at end of file diff --git a/src/utils/binary.ts b/src/utils/binary.ts index ebfed598..2c2dfa45 100644 --- a/src/utils/binary.ts +++ b/src/utils/binary.ts @@ -104,4 +104,26 @@ export function toArrayBuffer(like: ArrayBufferLike): ArrayBuffer { const arr = new Uint8Array(new ArrayBuffer(like.byteLength)) arr.set(new Uint8Array(like), 0) return arr.buffer +} + +export async function* parseSSEResponses(reader: ReadableStreamDefaultReader>, endStr?: string): AsyncGenerator { + const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) + while (true) { + const { done, value } = await reader.read() + if (done) break + const decoded = decoder.decode(value, { stream: true }) + const textValues = decoded.split('\n\n') // sometimes it will fetch multiple lines + for (const textValue of textValues) { + if (textValue.trim() === '') continue + if (!textValue.startsWith('data:')) continue + const jsonValue = textValue.slice(5).trim() + if (endStr && jsonValue === endStr) break + try { + const { response } = JSON.parse(jsonValue) + yield response + } catch (err) { + throw new Error(`error while parsing '${jsonValue}': ${err.message ?? err}`) + } + } + } } \ No newline at end of file diff --git a/src/utils/fetch.ts b/src/utils/fetch.ts index 658deca1..e09eaed8 100644 --- a/src/utils/fetch.ts +++ b/src/utils/fetch.ts @@ -151,7 +151,7 @@ export async function sendRequest(request: RequestBody): Promise { } -export type CacheInfo ={ +export type CacheInfo = { data: T timestamp: number } From 32ddd4a1ebbc21f9b6e77ac77dd27298df039c36 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Mon, 21 Oct 2024 10:23:41 +0800 Subject: [PATCH 06/15] added addHandlerOnce on forwarder --- src/background/forwards.ts | 12 ++++++++++++ src/hooks/forwarder.ts | 15 +++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/background/forwards.ts b/src/background/forwards.ts index a0d57803..05812eae 100644 --- a/src/background/forwards.ts +++ b/src/background/forwards.ts @@ -90,6 +90,9 @@ export function isForwardMessage(message: any): message is For * forwarder.addHandler((data) => { * console.log('Received message:', data) * }) + * forwarder.addHandlerOnce((data) => { + * console.log('Received message:', data) + * }) * forwarder.sendForward('background', { message: 'Hello' }) */ export function getForwarder(command: K, target: ChannelType): Forwarder { @@ -132,6 +135,14 @@ export function getForwarder(command: K, target: Ch chrome.runtime.onMessage.addListener(fn) return () => chrome.runtime.onMessage.removeListener(fn) }, + addHandlerOnce: (handler: (data: R) => void): VoidCallback => { + const fn = listener((data: R) => { + handler(data) + chrome.runtime.onMessage.removeListener(fn) + }) + chrome.runtime.onMessage.addListener(fn) + return () => chrome.runtime.onMessage.removeListener(fn) + }, sendForward: (toTarget: C, body: T, queryInfo?: ChannelQueryInfo[C]): void => { sendForward(toTarget, command, body, queryInfo) } @@ -145,6 +156,7 @@ export function useDefaultHandler(): ForwardHandler { export type Forwarder = { addHandler: (handler: (data: ForwardResponse) => void) => VoidCallback + addHandlerOnce: (handler: (data: ForwardResponse) => void) => VoidCallback sendForward: (toTarget: C, body: ForwardBody, queryInfo?: ChannelQueryInfo[C]) => void } diff --git a/src/hooks/forwarder.ts b/src/hooks/forwarder.ts index 411cc633..6324d0b3 100644 --- a/src/hooks/forwarder.ts +++ b/src/hooks/forwarder.ts @@ -2,6 +2,7 @@ import { getForwarder, type ChannelType, type ForwardData, + type Forwarder, type ForwardResponse } from '~background/forwards' @@ -18,17 +19,22 @@ import { useEffect, useMemo } from 'react' * - `sendForward`: A function that sends a message with the specified command to the specified channel. The message body is passed as an argument to this function. * * @example - * const { addHandler, sendForward } = useForwarder('myCommand', 'background') + * const { addHandler, addHandlerOnce, sendForward } = useForwarder('myCommand', 'background') * * // Add a handler for 'myCommand' messages on the 'background' channel * addHandler((data) => { * console.log('Received data:', data) * }) + * + * // Add a one-time handler for 'myCommand' messages on the 'background' channel + * addHandlerOnce((data) => { + * console.log('Received data:', data) + * }) * * // Send a 'myCommand' message to the 'background' channel * sendForward('background', { myData: 'Hello, world!' }) */ -export function useForwarder(key: K, target: ChannelType) { +export function useForwarder(key: K, target: ChannelType): Forwarder { type R = ForwardResponse const removeFunc = new Set() @@ -47,6 +53,11 @@ export function useForwarder(key: K, target: Channe removeFunc.add(remover) return remover // auto remove on unmount or manual remove }, + addHandlerOnce: (handler: (data: R) => void): VoidCallback => { + const remover = forwarder.addHandlerOnce(handler) + removeFunc.add(remover) + return remover // auto remove on unmount or manual remove + }, sendForward: forwarder.sendForward }), [forwarder]) From 050d466cff7d0dd4cd184eadde9a6a2364fd08a3 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Mon, 21 Oct 2024 16:13:34 +0800 Subject: [PATCH 07/15] changed partial-test retries to 3 --- .github/workflows/partial-test.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/partial-test.yml b/.github/workflows/partial-test.yml index 97fd6c1a..5365faac 100644 --- a/.github/workflows/partial-test.yml +++ b/.github/workflows/partial-test.yml @@ -80,7 +80,8 @@ jobs: --grep=@scoped \ --pass-with-no-tests \ --global-timeout=3600000 \ - --max-failures=2 + --max-failures=2 \ + --retries=3 env: DEBUG: true fast-e2e-test: @@ -122,7 +123,8 @@ jobs: --pass-with-no-tests \ --global-timeout=3600000 \ --timeout=60000 \ - --max-failures=5 + --max-failures=5 \ + --retries=3 env: DEBUG: true - name: Upload Test Results From 1a5eacc43b534208e964f75fbb972cb451f8728d Mon Sep 17 00:00:00 2001 From: eric2788 Date: Tue, 22 Oct 2024 15:51:37 +0800 Subject: [PATCH 08/15] changed least jimakus to 25, error on empty response --- src/features/jimaku/components/ButtonArea.tsx | 4 ++-- src/tabs/summarizer.tsx | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/features/jimaku/components/ButtonArea.tsx b/src/features/jimaku/components/ButtonArea.tsx index f677e2a9..1c49b3d4 100644 --- a/src/features/jimaku/components/ButtonArea.tsx +++ b/src/features/jimaku/components/ButtonArea.tsx @@ -41,8 +41,8 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element { const [show, setShow] = useState(!info.isTheme) const summerize = async () => { - if (jimakus.length < 10) { - toast.warning('至少需要有10条同传字幕才可总结。') + if (jimakus.length < 25) { + toast.warning('至少需要有25条同传字幕才可总结。') return } await sendMessager('open-tab', { tab: 'summarizer', params: { roomId: info.room, title: info.title }, active: true, singleton: true }) diff --git a/src/tabs/summarizer.tsx b/src/tabs/summarizer.tsx index fe9c1340..06a0df7a 100644 --- a/src/tabs/summarizer.tsx +++ b/src/tabs/summarizer.tsx @@ -40,6 +40,7 @@ function App() { return } setTitle(roomTitle ?? `B站直播间 ${roomId}`) + // only run once after success const remover = forwarder.addHandler((data) => { if (data.roomId !== roomId) return console.debug('received ', data.jimakus.length, 'danmakus') @@ -51,9 +52,6 @@ function App() { const summarize = useCallback(async (danmakus: string[]) => { try { - if (danmakus.length < 10) { - throw new Error('至少需要有10条同传字幕才可总结。') - } const { jimaku: { aiZone } } = await getSettingStorage('settings.features') const llm = createLLM(aiZone) const summaryStream = llm.promptStream(`这位是一名在b站直播间直播的日本vtuber说过的话,请根据下文对话猜测与观众的互动内容,并用中文总结一下他们的对话:\n\n${danmakus.join('\n')}`) @@ -64,7 +62,11 @@ function App() { } catch (err) { setLoading(false) console.error(err) - setError('未知错误: ' + err.message) + setError('错误: ' + err.message) + } finally { + if (summary === '') { + setError('同传总结返回了空的回应。') + } } }, []) From ce6568385b337b1a60f4ca3ddcf3bd6b5d9266a8 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Tue, 22 Oct 2024 16:23:26 +0800 Subject: [PATCH 09/15] optimized open-tab singleton mechanism --- src/background/messages/open-tab.ts | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/background/messages/open-tab.ts b/src/background/messages/open-tab.ts index 36d06fb4..9c854883 100644 --- a/src/background/messages/open-tab.ts +++ b/src/background/messages/open-tab.ts @@ -5,24 +5,26 @@ export type RequestBody = { tab?: string active?: boolean params?: Record - singleton?: boolean + singleton?: boolean | string[] } const handler: PlasmoMessaging.MessageHandler = async (req, res) => { const { url, tab, active } = req.body const queryString = req.body.params ? `?${new URLSearchParams(req.body.params).toString()}` : '' const fullUrl = tab ? chrome.runtime.getURL(`/tabs/${tab}.html${queryString}`) : url + queryString + const pathUrl = (tab ? chrome.runtime.getURL(`/tabs/${tab}.html`) : url) + '*' if (req.body.singleton) { - const tabs = await chrome.tabs.query({ url: fullUrl }) - if (tabs.length) { - const tab = tabs[0] - await chrome.tabs.update(tab.id, { active: true }) - res.send(tab) + const tabs = await chrome.tabs.query({ url: typeof req.body.singleton === 'boolean' ? fullUrl : pathUrl }) + const tab = tabs.find(tab => + typeof req.body.singleton === 'boolean' || + req.body.singleton.some(param => new URL(tab.url).searchParams.get(param) === req.body.params[param]) + ) + if (tab) { + res.send(await chrome.tabs.update(tab.id, { active: true })) return } } - const result = await chrome.tabs.create({ url: fullUrl, active }) - res.send(result) + res.send(await chrome.tabs.create({ url: fullUrl, active })) } From a2f92d7e900ef837d1733aa1796213c91ca498b0 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Tue, 22 Oct 2024 18:10:25 +0800 Subject: [PATCH 10/15] reshaped ai schema and model selectable --- src/api/cloudflare.ts | 16 ++- src/features/jimaku/components/ButtonArea.tsx | 2 +- src/llms/{cf-qwen.ts => cloudflare-ai.ts} | 32 +++-- src/llms/index.ts | 13 +- src/llms/remote-worker.ts | 12 +- .../features/jimaku/components/AIFragment.tsx | 113 +-------------- src/options/fragments.ts | 2 + src/options/fragments/llm.tsx | 133 ++++++++++++++++++ src/tabs/summarizer.tsx | 20 +-- src/types/cloudflare/index.ts | 2 +- tests/features/jimaku.spec.ts | 2 +- tests/integrations/summarizer.spec.ts | 70 +++++++++ tests/units/llm.spec.ts | 10 +- 13 files changed, 277 insertions(+), 150 deletions(-) rename src/llms/{cf-qwen.ts => cloudflare-ai.ts} (50%) create mode 100644 src/options/fragments/llm.tsx create mode 100644 tests/integrations/summarizer.spec.ts diff --git a/src/api/cloudflare.ts b/src/api/cloudflare.ts index 9cba0130..321d7861 100644 --- a/src/api/cloudflare.ts +++ b/src/api/cloudflare.ts @@ -16,7 +16,7 @@ export async function runAI(data: any, { token, account, model }: { token: strin return json } -export async function *runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator { +export async function* runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator { const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, { method: 'POST', headers: { @@ -35,12 +35,18 @@ export async function *runAIStream(data: any, { token, account, model }: { token } } -export async function validateAIToken(accountId: string, token: string): Promise { - const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?per_page=1`, { +export async function validateAIToken(accountId: string, token: string, model: string): Promise { + const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?search=${model}&per_page=1`, { headers: { - Authorization: `Bearer ${this.apiToken}` + Authorization: `Bearer ${token}` } }) const data = await res.json() as Result - return data.success + if (!data.success) { + return false + } else if (data.result.length === 0) { + return '找不到指定 AI 模型' + } else { + return true + } } \ No newline at end of file diff --git a/src/features/jimaku/components/ButtonArea.tsx b/src/features/jimaku/components/ButtonArea.tsx index 1c49b3d4..e30b7411 100644 --- a/src/features/jimaku/components/ButtonArea.tsx +++ b/src/features/jimaku/components/ButtonArea.tsx @@ -73,7 +73,7 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element { 弹出同传视窗 } - {aiZone.enabled && ( + {aiZone.summarizeEnabled && ( 同传字幕AI总结 diff --git a/src/llms/cf-qwen.ts b/src/llms/cloudflare-ai.ts similarity index 50% rename from src/llms/cf-qwen.ts rename to src/llms/cloudflare-ai.ts index c68420b1..c122c394 100644 --- a/src/llms/cf-qwen.ts +++ b/src/llms/cloudflare-ai.ts @@ -1,28 +1,38 @@ import { runAI, runAIStream, validateAIToken } from "~api/cloudflare"; import type { LLMProviders, Session } from "~llms"; +import type { SettingSchema } from "~options/fragments/llm"; -export default class CloudFlareQwen implements LLMProviders { +export default class CloudFlareAI implements LLMProviders { - private static readonly MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq' + private static readonly DEFAULT_MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq' - constructor( - private readonly accountId: string, - private readonly apiToken: string, - ) { } + private readonly accountId: string + private readonly apiToken: string + + private readonly model: string + + constructor(settings: SettingSchema) { + this.accountId = settings.accountId + this.apiToken = settings.apiToken + + // only text generation model for now + this.model = settings.model || CloudFlareAI.DEFAULT_MODEL + } async validate(): Promise { - const success = await validateAIToken(this.accountId, this.apiToken) - if (!success) throw new Error('Cloudflare API 验证失败') + const success = await validateAIToken(this.accountId, this.apiToken, this.model) + if (typeof success === 'boolean' && !success) throw new Error('Cloudflare API 验证失败') + if (typeof success === 'string') throw new Error(success) } async prompt(chat: string): Promise { - const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL }) + const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model }) if (!res.result) throw new Error(res.errors.join(', ')) return res.result.response } async *promptStream(chat: string): AsyncGenerator { - return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL }) + return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model }) } async asSession(): Promise> { @@ -33,6 +43,8 @@ export default class CloudFlareQwen implements LLMProviders { } } + // text generation model input schema + // so only text generation model for now private wrap(chat: string): any { return { max_tokens: 512, diff --git a/src/llms/index.ts b/src/llms/index.ts index 52675e43..a7daee5f 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -1,4 +1,6 @@ -import qwen from './cf-qwen' +import type { SettingSchema as LLMSchema } from '~options/fragments/llm' + +import cloudflare from './cloudflare-ai' import nano from './gemini-nano' import worker from './remote-worker' @@ -12,7 +14,7 @@ export interface LLMProviders { export type Session = Disposable & Omit const llms = { - qwen, + cloudflare, nano, worker } @@ -21,9 +23,10 @@ export type LLMs = typeof llms export type LLMTypes = keyof LLMs -function createLLMProvider(type: K, ...args: ConstructorParameters): LLMProviders { - const LLM = llms[type].bind(this, ...args) - return new LLM() +function createLLMProvider(settings: LLMSchema): LLMProviders { + const type = settings.provider + const LLM = llms[type] + return new LLM(settings) } export default createLLMProvider \ No newline at end of file diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts index 9293db38..d2f08c78 100644 --- a/src/llms/remote-worker.ts +++ b/src/llms/remote-worker.ts @@ -1,10 +1,16 @@ import type { LLMProviders, Session } from "~llms"; +import type { SettingSchema } from "~options/fragments/llm"; import { parseSSEResponses } from "~utils/binary"; - // for my worker, so limited usage export default class RemoteWorker implements LLMProviders { + private readonly model?: string + + constructor(settings: SettingSchema) { + this.model = settings.model || undefined + } + async validate(): Promise { const res = await fetch('https://llm.ericlamm.xyz/status') const json = await res.json() @@ -19,7 +25,7 @@ export default class RemoteWorker implements LLMProviders { headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ prompt: chat }) + body: JSON.stringify({ prompt: chat, model: this.model }) }) if (!res.ok) throw new Error(await res.text()) const json = await res.json() @@ -32,7 +38,7 @@ export default class RemoteWorker implements LLMProviders { headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ prompt: chat, stream: true }) + body: JSON.stringify({ prompt: chat, stream: true, model: this.model }) }) if (!res.ok) throw new Error(await res.text()) if (!res.body) throw new Error('Remote worker response body is not readable') diff --git a/src/options/features/jimaku/components/AIFragment.tsx b/src/options/features/jimaku/components/AIFragment.tsx index ec0eda64..83e48e2c 100644 --- a/src/options/features/jimaku/components/AIFragment.tsx +++ b/src/options/features/jimaku/components/AIFragment.tsx @@ -1,134 +1,35 @@ -import { Button, Input, List, Tooltip, Typography } from "@material-tailwind/react" -import { type ChangeEvent, Fragment, useState } from "react" -import { toast } from "sonner/dist" +import { List } from "@material-tailwind/react" +import { type ChangeEvent, Fragment } from "react" import type { StateProxy } from "~hooks/binding" -import type { LLMProviders, LLMTypes } from "~llms" -import createLLMProvider from "~llms" import ExperienmentFeatureIcon from "~options/components/ExperientmentFeatureIcon" -import Selector from "~options/components/Selector" import SwitchListItem from "~options/components/SwitchListItem" - - export type AISchema = { - enabled: boolean - provider: LLMTypes - - // cloudflare settings - accountId?: string - apiToken?: string + summarizeEnabled: boolean } export const aiDefaultSettings: Readonly = { - enabled: false, - provider: 'worker' + summarizeEnabled: false } function AIFragment({ state, useHandler }: StateProxy): JSX.Element { - const [validating, setValidating] = useState(false) - - const handler = useHandler, string>((e) => e.target.value) const checker = useHandler, boolean>((e) => e.target.checked) - const onValidate = async () => { - setValidating(true) - try { - let provider: LLMProviders; - if (state.provider === 'qwen') { - provider = createLLMProvider(state.provider, state.accountId, state.apiToken) - } else { - provider = createLLMProvider(state.provider) - } - await provider.validate() - toast.success('配置可用!') - } catch (e) { - toast.error('配置不可用: ' + e.message) - } finally { - setValidating(false) - } - } - return ( } /> - {state.enabled && ( - - - className="col-span-2" - data-testid="ai-provider" - label="技术来源" - value={state.provider} - onChange={e => state.provider = e} - options={[ - { label: 'Cloudflare AI', value: 'qwen' }, - { label: '有限度服务器', value: 'worker' }, - { label: 'Chrome 浏览器内置 AI', value: 'nano' } - ]} - /> - {state.provider === 'qwen' && ( - - - - - - 点击此处 - 查看如何获得 Cloudflare API Token 和 Account ID - - - - - )} - - )} -
- -
) } diff --git a/src/options/fragments.ts b/src/options/fragments.ts index 8b9a25b2..ed6228d5 100644 --- a/src/options/fragments.ts +++ b/src/options/fragments.ts @@ -5,6 +5,7 @@ import * as display from './fragments/display' import * as features from './fragments/features' import * as listings from './fragments/listings' import * as version from './fragments/version' +import * as llm from './fragments/llm' interface SettingFragment { @@ -28,6 +29,7 @@ const fragments = { 'settings.listings': listings, 'settings.capture': capture, 'settings.display': display, + 'settings.llm': llm, 'settings.developer': developer, 'settings.version': version } diff --git a/src/options/fragments/llm.tsx b/src/options/fragments/llm.tsx new file mode 100644 index 00000000..85a7261d --- /dev/null +++ b/src/options/fragments/llm.tsx @@ -0,0 +1,133 @@ +import { Button, Input, Tooltip, Typography } from "@material-tailwind/react" +import { Fragment, useState, type ChangeEvent } from "react" +import { toast } from "sonner/dist" +import type { StateProxy } from "~hooks/binding" +import type { LLMTypes } from "~llms" +import createLLMProvider from "~llms" +import Selector from "~options/components/Selector" + +export type SettingSchema = { + provider: LLMTypes + + // cloudflare settings + accountId?: string + apiToken?: string + + // cloudflare and worker settings + model?: string +} + +export const defaultSettings: Readonly = { + provider: 'worker', + model: '@cf/qwen/qwen1.5-14b-chat-awq' +} + +export const title = 'AI 模型设定' + +export const description = [ + '此设定区块包含了大语言模型(LLM)相关的设定,用于为插件提供 AI 功能。', + '技术提供默认为公共的服务器,质量可能不稳定,建议设置为 Cloudflare 作为技术提供来源。' +] + +function LLMSettings({ state, useHandler }: StateProxy): JSX.Element { + + const [validating, setValidating] = useState(false) + const handler = useHandler, string>((e) => e.target.value) + + const onValidate = async () => { + setValidating(true) + try { + const provider = createLLMProvider(state) + await provider.validate() + toast.success('配置可用!') + } catch (e) { + toast.error('配置不可用: ' + e.message) + } finally { + setValidating(false) + } + } + + return ( + + + className="col-span-2" + data-testid="ai-provider" + label="技术提供" + value={state.provider} + onChange={e => state.provider = e} + options={[ + { label: 'Cloudflare AI', value: 'cloudflare' }, + { label: '公共服务器', value: 'worker' }, + { label: 'Chrome 浏览器内置 AI', value: 'nano' } + ]} + /> + {state.provider === 'cloudflare' && ( + + + + + + 点击此处 + 查看如何获得 Cloudflare API Token 和 Account ID + + + + + )} + {['cloudflare', 'worker'].includes(state.provider) && ( + + data-testid="ai-model" + label="模型提供" + value={state.model} + onChange={e => state.model = e} + options={[ + { label: '@cf/qwen/qwen1.5-14b-chat-awq', value: '@cf/qwen/qwen1.5-14b-chat-awq' }, + { label: '@cf/qwen/qwen1.5-7b-chat-awq', value: '@cf/qwen/qwen1.5-7b-chat-awq' }, + { label: '@cf/qwen/qwen1.5-1.8b-chat', value: '@cf/qwen/qwen1.5-1.8b-chat' }, + { label: '@hf/google/gemma-7b-it', value: '@hf/google/gemma-7b-it' }, + { label: '@hf/nousresearch/hermes-2-pro-mistral-7b', value: '@hf/nousresearch/hermes-2-pro-mistral-7b' } + ]} + /> + )} +
+ +
+
+ ) +} + +export default LLMSettings \ No newline at end of file diff --git a/src/tabs/summarizer.tsx b/src/tabs/summarizer.tsx index 06a0df7a..9bccd0bd 100644 --- a/src/tabs/summarizer.tsx +++ b/src/tabs/summarizer.tsx @@ -1,28 +1,18 @@ +import '~style.css'; + import { Typography } from "@material-tailwind/react"; import icon from 'raw:assets/icon.png'; import { useCallback, useDeferredValue, useEffect, useState } from "react"; import { useForwarder } from "~hooks/forwarder"; import ChatBubble from "~components/ChatBubble"; -import createLLMProvider, { type LLMProviders } from "~llms"; -import type { AISchema } from "~options/features/jimaku/components/AIFragment"; -import '~style.css'; +import createLLMProvider from "~llms"; import { getSettingStorage } from "~utils/storage"; const urlParams = new URLSearchParams(window.location.search); const roomId = urlParams.get('roomId') const roomTitle = urlParams.get('title') -function createLLM(schema: AISchema): LLMProviders { - switch (schema.provider) { - case 'worker': - case 'nano': - return createLLMProvider(schema.provider) - case 'qwen': - return createLLMProvider(schema.provider, schema.accountId, schema.apiToken) - } -} - const loadingText = '正在加载同传字幕总结.....' function App() { @@ -52,8 +42,8 @@ function App() { const summarize = useCallback(async (danmakus: string[]) => { try { - const { jimaku: { aiZone } } = await getSettingStorage('settings.features') - const llm = createLLM(aiZone) + const llmSettings = await getSettingStorage('settings.llm') + const llm = createLLMProvider(llmSettings) const summaryStream = llm.promptStream(`这位是一名在b站直播间直播的日本vtuber说过的话,请根据下文对话猜测与观众的互动内容,并用中文总结一下他们的对话:\n\n${danmakus.join('\n')}`) setLoading(false) for await (const words of summaryStream) { diff --git a/src/types/cloudflare/index.ts b/src/types/cloudflare/index.ts index 34351d67..a0646600 100644 --- a/src/types/cloudflare/index.ts +++ b/src/types/cloudflare/index.ts @@ -3,6 +3,6 @@ export * from './workers-ai' export type Result = { success: boolean result: T - errors: string[] + errors: { code: number, message: string}[] messages: string[] } \ No newline at end of file diff --git a/tests/features/jimaku.spec.ts b/tests/features/jimaku.spec.ts index ef1b8060..1bb0794c 100644 --- a/tests/features/jimaku.spec.ts +++ b/tests/features/jimaku.spec.ts @@ -207,7 +207,7 @@ test('测试同传字幕AI总结', { tag: "@scoped" }, async ({ room, content: p logger.debug('AI Summary:', res) const maybe = expect.configure({ soft: true }) - maybe(res).toMatch(/主播|日本VTuber/) + maybe(res).toMatch(/主播|日本VTuber|日本vtuber|vtuber/) maybe(res).toMatch(/直播|观众/) maybe(res).toContain('麦当劳') maybe(res).toContain('漫展') diff --git a/tests/integrations/summarizer.spec.ts b/tests/integrations/summarizer.spec.ts new file mode 100644 index 00000000..b2af11e0 --- /dev/null +++ b/tests/integrations/summarizer.spec.ts @@ -0,0 +1,70 @@ +import { expect, test } from "@tests/fixtures/component"; +import logger from "@tests/helpers/logger"; +import createLLMProvider, { type LLMTypes } from "~llms"; + +const prompt = `这位是一名在b站直播间直播的日本vtuber说过的话,请根据下文对话猜测与观众的互动内容,并用中文总结一下他们的对话:\n\n${[ + '大家好', + '早上好', + '知道我今天吃了什么吗?', + '是麦当劳哦!', + '"不就个麦当劳而已吗"不是啦', + '是最近那个很热门的新品', + '对,就是那个', + '然后呢, 今天久违的出门了', + '对,平时都是宅在家里的呢', + '"终于长大了"喂w', + '然后今天去了漫展来着', + '很多人呢', + '之前的我看到那么多人肯定社恐了', + '但今次意外的没有呢', + '"果然是长大了"也是呢', + '然后呢, 今天买了很多东西', + '插画啊,手办啊,周边之类的', + '荷包大出血w', + '不过觉得花上去应该值得的...吧?', + '喂,好过分啊', + '不过确实不应该花那么多钱的', + '然后呢,回家途中看到了蟑螂的尸体', + '太恶心了', + '然后把我一整天好心情搞没了w', + '"就因为一个蟑螂"对www', + '不过跟你们谈完反而心情好多了', + '谢谢大家', + '那么今天的杂谈就到这里吧', + '下次再见啦', + '拜拜~' +].join('\n')}` as const + +function testModel(model: string, provider: LLMTypes = 'worker') { + return async function () { + + logger.info(`正在测试模型 ${model} ...`) + + const llm = createLLMProvider({ + provider, + model + }) + + const res = await llm.prompt(prompt) + logger.info(`模型 ${model} 的总结结果`, res) + + const maybe = expect.configure({ soft: true }) + maybe(res).toMatch(/主播|日本VTuber|日本vtuber|vtuber/) + maybe(res).toMatch(/直播|观众/) + maybe(res).toContain('麦当劳') + maybe(res).toContain('漫展') + maybe(res).toContain('蟑螂') + } +} + +test.slow() + +test('测试 @cf/qwen/qwen1.5-14b-chat-awq 模型的AI总结结果', { tag: "@scoped" }, testModel('@cf/qwen/qwen1.5-14b-chat-awq')) + +test('测试 @cf/qwen/qwen1.5-7b-chat-awq 模型的AI总结结果', { tag: "@scoped" },testModel('@cf/qwen/qwen1.5-7b-chat-awq')) + +test('测试 @cf/qwen/qwen1.5-1.8b-chat 模型的AI总结结果', { tag: "@scoped" },testModel('@cf/qwen/qwen1.5-1.8b-chat')) + +test('测试 @hf/google/gemma-7b-it 模型的AI总结结果', { tag: "@scoped" }, testModel('@hf/google/gemma-7b-it')) + +test('测试 @hf/nousresearch/hermes-2-pro-mistral-7b 模型的AI总结结果', { tag: "@scoped" }, testModel('@hf/nousresearch/hermes-2-pro-mistral-7b')) \ No newline at end of file diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts index 6fd13cc9..67e6dd9e 100644 --- a/tests/units/llm.spec.ts +++ b/tests/units/llm.spec.ts @@ -7,7 +7,11 @@ test('嘗試使用 Cloudflare AI 對話', { tag: "@scoped" }, async () => { test.skip(!process.env.CF_ACCOUNT_ID || !process.env.CF_API_TOKEN, '請設定 CF_ACCOUNT_ID 和 CF_API_TOKEN 環境變數') - const llm = createLLMProvider('qwen', process.env.CF_ACCOUNT_ID, process.env.CF_API_TOKEN) + const llm = createLLMProvider({ + provider: 'cloudflare', + accountId: process.env.CF_ACCOUNT_ID, + apiToken: process.env.CF_API_TOKEN + }) logger.info('正在测试 json 返回请求...') const res = await llm.prompt('你好') @@ -44,7 +48,7 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul const ret = await page.evaluate(async () => { const { llms } = window as any console.log('llms: ', llms) - const llm = await llms.createLLMProvider('nano') + const llm = await llms.createLLMProvider({ provider: 'nano' }) return await llm.prompt('你好') }) @@ -54,7 +58,7 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul test('嘗試使用 Remote Worker 對話', { tag: "@scoped" }, async () => { - const llm = createLLMProvider('worker') + const llm = createLLMProvider({ provider: 'worker' }) logger.info('正在测试 json 返回请求...') const res = await llm.prompt('你好') From 86efdab06badc92f7c594673da939631092e5b59 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Wed, 23 Oct 2024 11:09:45 +0800 Subject: [PATCH 11/15] optimized test cases for ai feature --- tests/features/jimaku.spec.ts | 2 +- tests/integrations/summarizer.spec.ts | 18 +++++++++++------- tests/units/llm.spec.ts | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/features/jimaku.spec.ts b/tests/features/jimaku.spec.ts index 1bb0794c..7a653022 100644 --- a/tests/features/jimaku.spec.ts +++ b/tests/features/jimaku.spec.ts @@ -203,7 +203,7 @@ test('测试同传字幕AI总结', { tag: "@scoped" }, async ({ room, content: p logger.info('正在測試AI总結結果... (15s)') await summarizer.waitForTimeout(15000) await expect(summarizer.getByText('错误')).toBeHidden({ timeout: 5000 }) - const res = await summarizer.getByTestId('同传字幕总结-bubble-chat-0').locator('h5.leading-snug').textContent() + const res = await summarizer.getByTestId('同传字幕总结-bubble-chat-0').locator('div.leading-snug').textContent() logger.debug('AI Summary:', res) const maybe = expect.configure({ soft: true }) diff --git a/tests/integrations/summarizer.spec.ts b/tests/integrations/summarizer.spec.ts index b2af11e0..35ab987f 100644 --- a/tests/integrations/summarizer.spec.ts +++ b/tests/integrations/summarizer.spec.ts @@ -35,7 +35,7 @@ const prompt = `这位是一名在b站直播间直播的日本vtuber说过的话 '拜拜~' ].join('\n')}` as const -function testModel(model: string, provider: LLMTypes = 'worker') { +function testModel(model: string, { trash = false, provider = 'worker' }: { trash?: boolean, provider?: LLMTypes } = {}) { return async function () { logger.info(`正在测试模型 ${model} ...`) @@ -51,9 +51,12 @@ function testModel(model: string, provider: LLMTypes = 'worker') { const maybe = expect.configure({ soft: true }) maybe(res).toMatch(/主播|日本VTuber|日本vtuber|vtuber/) maybe(res).toMatch(/直播|观众/) - maybe(res).toContain('麦当劳') - maybe(res).toContain('漫展') - maybe(res).toContain('蟑螂') + + if (!trash) { + maybe(res).toContain('麦当劳') + maybe(res).toContain('漫展') + maybe(res).toContain('蟑螂') + } } } @@ -61,10 +64,11 @@ test.slow() test('测试 @cf/qwen/qwen1.5-14b-chat-awq 模型的AI总结结果', { tag: "@scoped" }, testModel('@cf/qwen/qwen1.5-14b-chat-awq')) -test('测试 @cf/qwen/qwen1.5-7b-chat-awq 模型的AI总结结果', { tag: "@scoped" },testModel('@cf/qwen/qwen1.5-7b-chat-awq')) +test('测试 @cf/qwen/qwen1.5-7b-chat-awq 模型的AI总结结果', { tag: "@scoped" }, testModel('@cf/qwen/qwen1.5-7b-chat-awq')) -test('测试 @cf/qwen/qwen1.5-1.8b-chat 模型的AI总结结果', { tag: "@scoped" },testModel('@cf/qwen/qwen1.5-1.8b-chat')) +test('测试 @cf/qwen/qwen1.5-1.8b-chat 模型的AI总结结果', { tag: "@scoped" }, testModel('@cf/qwen/qwen1.5-1.8b-chat')) -test('测试 @hf/google/gemma-7b-it 模型的AI总结结果', { tag: "@scoped" }, testModel('@hf/google/gemma-7b-it')) +// this model is too trash that cannot have any keywords +test('测试 @hf/google/gemma-7b-it 模型的AI总结结果', { tag: "@scoped" }, testModel('@hf/google/gemma-7b-it', { trash: true })) test('测试 @hf/nousresearch/hermes-2-pro-mistral-7b 模型的AI总结结果', { tag: "@scoped" }, testModel('@hf/nousresearch/hermes-2-pro-mistral-7b')) \ No newline at end of file diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts index 67e6dd9e..0aa13846 100644 --- a/tests/units/llm.spec.ts +++ b/tests/units/llm.spec.ts @@ -48,7 +48,7 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul const ret = await page.evaluate(async () => { const { llms } = window as any console.log('llms: ', llms) - const llm = await llms.createLLMProvider({ provider: 'nano' }) + const llm = llms.createLLMProvider({ provider: 'nano' }) return await llm.prompt('你好') }) From 48d015a73d768cefa4057c37d4ab688018e573e4 Mon Sep 17 00:00:00 2001 From: eric2788 Date: Wed, 23 Oct 2024 15:21:25 +0800 Subject: [PATCH 12/15] convert ai response with markdown convert --- package.json | 1 + pnpm-lock.yaml | 35 ++++++++++------------ src/components/ChatBubble.tsx | 2 +- src/hooks/{input.ts => form.ts} | 4 +-- src/llms/cloudflare-ai.ts | 2 ++ src/llms/gemini-nano.ts | 37 +++++++++++++++++------- src/llms/index.ts | 3 +- src/llms/remote-worker.ts | 4 ++- src/options/fragments/llm.tsx | 51 +++++++++++++++++++++------------ src/options/index.tsx | 2 +- src/style.css | 10 +++++-- src/tabs/summarizer.tsx | 19 ++++++++---- tests/units/llm.spec.ts | 14 +++++++++ 13 files changed, 120 insertions(+), 64 deletions(-) rename src/hooks/{input.ts => form.ts} (93%) diff --git a/package.json b/package.json index 3fe907c8..f766ba57 100644 --- a/package.json +++ b/package.json @@ -36,6 +36,7 @@ "dexie-react-hooks": "^1.1.7", "hash-wasm": "^4.11.0", "hls.js": "^1.5.8", + "markdown-to-jsx": "^7.5.0", "media-chrome": "^2.2.5", "mpegts.js": "^1.7.3", "n-danmaku": "^2.2.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 84cbe4cd..a98c6037 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -56,6 +56,9 @@ importers: hls.js: specifier: ^1.5.8 version: 1.5.8 + markdown-to-jsx: + specifier: ^7.5.0 + version: 7.5.0(react@18.2.0) media-chrome: specifier: ^2.2.5 version: 2.2.5 @@ -2285,15 +2288,6 @@ packages: csstype@3.1.3: resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} - debug@4.3.4: - resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} - engines: {node: '>=6.0'} - peerDependencies: - supports-color: '*' - peerDependenciesMeta: - supports-color: - optional: true - debug@4.3.7: resolution: {integrity: sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==} engines: {node: '>=6.0'} @@ -3075,6 +3069,12 @@ packages: resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==} engines: {node: '>=6'} + markdown-to-jsx@7.5.0: + resolution: {integrity: sha512-RrBNcMHiFPcz/iqIj0n3wclzHXjwS7mzjBNWecKKVhNTIxQepIix6Il/wZCn2Cg5Y1ow2Qi84+eJrryFRWBEWw==} + engines: {node: '>= 10'} + peerDependencies: + react: '>= 0.14.0' + material-ripple-effects@2.0.1: resolution: {integrity: sha512-hHlUkZAuXbP94lu02VgrPidbZ3hBtgXBtjlwR8APNqOIgDZMV8MCIcsclL8FmGJQHvnORyvoQgC965vPsiyXLQ==} @@ -3144,9 +3144,6 @@ packages: mpegts.js@1.7.3: resolution: {integrity: sha512-kqZ1C1IsbAQN72cK8vMrzKeM7hwrwSBbFAwVAc7PPweOeoZxCANrc7fAVDKMfYUzxdNkMTnec9tVmlxmKZB0TQ==} - ms@2.1.2: - resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} - ms@2.1.3: resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} @@ -4065,7 +4062,7 @@ snapshots: '@babel/traverse': 7.24.1 '@babel/types': 7.24.0 convert-source-map: 2.0.0 - debug: 4.3.4 + debug: 4.3.7 gensync: 1.0.0-beta.2 json5: 2.2.3 semver: 6.3.1 @@ -4255,7 +4252,7 @@ snapshots: '@babel/helper-split-export-declaration': 7.22.6 '@babel/parser': 7.24.1 '@babel/types': 7.24.0 - debug: 4.3.4 + debug: 4.3.7 globals: 11.12.0 transitivePeerDependencies: - supports-color @@ -6490,10 +6487,6 @@ snapshots: csstype@3.1.3: {} - debug@4.3.4: - dependencies: - ms: 2.1.2 - debug@4.3.7: dependencies: ms: 2.1.3 @@ -7236,6 +7229,10 @@ snapshots: semver: 5.7.2 optional: true + markdown-to-jsx@7.5.0(react@18.2.0): + dependencies: + react: 18.2.0 + material-ripple-effects@2.0.1: {} mdn-data@2.0.14: {} @@ -7286,8 +7283,6 @@ snapshots: es6-promise: 4.2.8 webworkify-webpack: 2.1.5 - ms@2.1.2: {} - ms@2.1.3: {} msgpackr-extract@3.0.3: diff --git a/src/components/ChatBubble.tsx b/src/components/ChatBubble.tsx index fc57bb86..d6898804 100644 --- a/src/components/ChatBubble.tsx +++ b/src/components/ChatBubble.tsx @@ -27,7 +27,7 @@ function ChatBubble(props: ChatBubbleProps): JSX.Element { {messages.map((message, index) => (
-
{message.text}
+
{message.text}
{message.time && (
diff --git a/src/hooks/input.ts b/src/hooks/form.ts similarity index 93% rename from src/hooks/input.ts rename to src/hooks/form.ts index 964ee2ac..4a5595f7 100644 --- a/src/hooks/input.ts +++ b/src/hooks/form.ts @@ -1,9 +1,9 @@ import { useCallback, useRef } from "react"; -export function useFileInput(onFileChange: (files: FileList) => Promise, onError?: (e: Error | any) => void, deps: any[] = []){ +export function useFileInput(onFileChange: (files: FileList) => Promise, onError?: (e: Error | any) => void, deps: any[] = []) { const inputRef = useRef() - const selectFiles = useCallback(function(): Promise { + const selectFiles = useCallback(function (): Promise { return new Promise((resolve, reject) => { const listener = async (e: Event) => { try { diff --git a/src/llms/cloudflare-ai.ts b/src/llms/cloudflare-ai.ts index c122c394..260c855c 100644 --- a/src/llms/cloudflare-ai.ts +++ b/src/llms/cloudflare-ai.ts @@ -19,6 +19,8 @@ export default class CloudFlareAI implements LLMProviders { this.model = settings.model || CloudFlareAI.DEFAULT_MODEL } + cumulative: boolean = true + async validate(): Promise { const success = await validateAIToken(this.accountId, this.apiToken, this.model) if (typeof success === 'boolean' && !success) throw new Error('Cloudflare API 验证失败') diff --git a/src/llms/gemini-nano.ts b/src/llms/gemini-nano.ts index 8d2e79a0..f2da1f4e 100644 --- a/src/llms/gemini-nano.ts +++ b/src/llms/gemini-nano.ts @@ -2,6 +2,8 @@ import type { LLMProviders, Session } from "~llms" export default class GeminiNano implements LLMProviders { + cumulative: boolean = false + async validate(): Promise { if (!window.ai) throw new Error('你的浏览器没有启用 AI 功能') if (!window.ai.languageModel && @@ -13,8 +15,10 @@ export default class GeminiNano implements LLMProviders { async prompt(chat: string): Promise { const session = await this.asSession() try { + console.debug('[gemini nano] prompting: ', chat) return session.prompt(chat) } finally { + console.debug('[gemini nano] done') session[Symbol.dispose]() } } @@ -22,34 +26,41 @@ export default class GeminiNano implements LLMProviders { async *promptStream(chat: string): AsyncGenerator { const session = await this.asSession() try { - return session.promptStream(chat) + console.debug('[gemini nano] prompting stream: ', chat) + const res = session.promptStream(chat) + for await (const chunk of res) { + yield chunk + } } finally { + console.debug('[gemini nano] done') session[Symbol.dispose]() } } async asSession(): Promise> { - if (window.ai.assistant || window.ai.languageModel) { - const assistant = window.ai.assistant ?? window.ai.languageModel - const capabilities = await assistant.capabilities() - if (capabilities.available === 'readily') { - return new GeminiAssistant(await assistant.create()) - } else { - console.warn('AI Assistant 当前不可用: ', capabilities) - } - } - if (window.ai.summarizer) { const summarizer = window.ai.summarizer const capabilities = await summarizer.capabilities() if (capabilities.available === 'readily') { + console.debug('using gemini summarizer') return new GeminiSummarizer(await summarizer.create()) } else { console.warn('AI Summarizer 当前不可用: ', capabilities) } } + if (window.ai.assistant || window.ai.languageModel) { + const assistant = window.ai.assistant ?? window.ai.languageModel + const capabilities = await assistant.capabilities() + if (capabilities.available === 'readily') { + console.debug('using gemini assistant') + return new GeminiAssistant(await assistant.create()) + } else { + console.warn('AI Assistant 当前不可用: ', capabilities) + } + } + throw new Error('你的浏览器 AI 功能当前不可用') } } @@ -59,10 +70,12 @@ class GeminiAssistant implements Session { constructor(private readonly assistant: AIAssistant) { } prompt(chat: string): Promise { + console.debug('[assistant] prompting: ', chat) return this.assistant.prompt(chat) } async *promptStream(chat: string): AsyncGenerator { + console.debug('[assistant] prompting stream: ', chat) const stream = this.assistant.promptStreaming(chat) for await (const chunk of stream) { yield chunk @@ -80,10 +93,12 @@ class GeminiSummarizer implements Session { constructor(private readonly summarizer: AISummarizer) { } prompt(chat: string): Promise { + console.debug('[summarizer] summarizing: ', chat) return this.summarizer.summarize(chat) } async *promptStream(chat: string): AsyncGenerator { + console.debug('[summarizer] summarizing stream: ', chat) const stream = this.summarizer.summarizeStreaming(chat) for await (const chunk of stream) { yield chunk diff --git a/src/llms/index.ts b/src/llms/index.ts index a7daee5f..15031b7a 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -5,13 +5,14 @@ import nano from './gemini-nano' import worker from './remote-worker' export interface LLMProviders { + cumulative: boolean validate(): Promise prompt(chat: string): Promise promptStream(chat: string): AsyncGenerator asSession(): Promise> } -export type Session = Disposable & Omit +export type Session = Disposable & Omit const llms = { cloudflare, diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts index d2f08c78..2866014e 100644 --- a/src/llms/remote-worker.ts +++ b/src/llms/remote-worker.ts @@ -11,6 +11,8 @@ export default class RemoteWorker implements LLMProviders { this.model = settings.model || undefined } + cumulative: boolean = true + async validate(): Promise { const res = await fetch('https://llm.ericlamm.xyz/status') const json = await res.json() @@ -55,5 +57,5 @@ export default class RemoteWorker implements LLMProviders { [Symbol.dispose]: () => { } } } - + } \ No newline at end of file diff --git a/src/options/fragments/llm.tsx b/src/options/fragments/llm.tsx index 85a7261d..561f2a1d 100644 --- a/src/options/fragments/llm.tsx +++ b/src/options/fragments/llm.tsx @@ -1,5 +1,5 @@ import { Button, Input, Tooltip, Typography } from "@material-tailwind/react" -import { Fragment, useState, type ChangeEvent } from "react" +import { Fragment, useState, type ChangeEvent, type ReactNode } from "react" import { toast } from "sonner/dist" import type { StateProxy } from "~hooks/binding" import type { LLMTypes } from "~llms" @@ -29,6 +29,29 @@ export const description = [ '技术提供默认为公共的服务器,质量可能不稳定,建议设置为 Cloudflare 作为技术提供来源。' ] + +function Hints({ children }: { children: ReactNode }): JSX.Element { + return ( + + + + + {children} + + ) +} + function LLMSettings({ state, useHandler }: StateProxy): JSX.Element { const [validating, setValidating] = useState(false) @@ -63,24 +86,10 @@ function LLMSettings({ state, useHandler }: StateProxy): JSX.Elem /> {state.provider === 'cloudflare' && ( - - - - + 点击此处 查看如何获得 Cloudflare API Token 和 Account ID - + ): JSX.Elem { label: '@cf/qwen/qwen1.5-1.8b-chat', value: '@cf/qwen/qwen1.5-1.8b-chat' }, { label: '@hf/google/gemma-7b-it', value: '@hf/google/gemma-7b-it' }, { label: '@hf/nousresearch/hermes-2-pro-mistral-7b', value: '@hf/nousresearch/hermes-2-pro-mistral-7b' } - ]} + ]} /> )} + {state.provider === 'nano' && ( + + 点击此处 + 查看如何启用 Chrome 浏览器内置 AI + + )}
diff --git a/tests/units/llm.spec.ts b/tests/units/llm.spec.ts index 0aa13846..c30629d2 100644 --- a/tests/units/llm.spec.ts +++ b/tests/units/llm.spec.ts @@ -54,6 +54,20 @@ test('嘗試使用 Gemini Nano 對話', { tag: "@scoped" }, async ({ page, modul logger.info('response: ', ret) await expect(ret).not.toBeEmpty() + + const ret2 = await page.evaluate(async () => { + const { llms } = window as any + const llm = llms.createLLMProvider({ provider: 'nano' }) + const res = llm.promptStream('地球为什么是圆的?') + let msg = ''; + for await (const r of res) { + console.log('response: ', r) + msg = r + } + return msg + }) + + logger.info('stream response: ', ret2) }) test('嘗試使用 Remote Worker 對話', { tag: "@scoped" }, async () => { From 35a32dc06fc0c97bd52a695dd2eefaf11e9862bb Mon Sep 17 00:00:00 2001 From: eric2788 Date: Wed, 23 Oct 2024 18:12:16 +0800 Subject: [PATCH 13/15] added web-llm to llm provider --- package.json | 2 + pnpm-lock.yaml | 24 ++++++++ src/contents/index/mounter.tsx | 2 +- src/hooks/life-cycle.ts | 30 ++++++++++ src/llms/cloudflare-ai.ts | 2 +- src/llms/gemini-nano.ts | 8 +-- src/llms/index.ts | 8 ++- src/llms/models.ts | 33 +++++++++++ src/llms/remote-worker.ts | 2 +- src/llms/web-llm.ts | 87 +++++++++++++++++++++++++++++ src/options/components/Selector.tsx | 3 +- src/options/fragments/llm.tsx | 82 +++++++++++++++++---------- tests/units/llm.spec.ts | 64 +++++++++++++++++++++ tsconfig.json | 3 + 14 files changed, 311 insertions(+), 39 deletions(-) create mode 100644 src/llms/models.ts create mode 100644 src/llms/web-llm.ts diff --git a/package.json b/package.json index f766ba57..90d49282 100644 --- a/package.json +++ b/package.json @@ -27,6 +27,7 @@ "@ffmpeg/ffmpeg": "^0.12.10", "@ffmpeg/util": "^0.12.1", "@material-tailwind/react": "^2.1.9", + "@mlc-ai/web-llm": "^0.2.73", "@plasmohq/messaging": "^0.6.2", "@plasmohq/storage": "^1.9.3", "@react-hooks-library/core": "^0.5.2", @@ -64,6 +65,7 @@ "@types/react": "18.2.37", "@types/react-dom": "18.2.15", "@types/semver": "^7.5.8", + "@webgpu/types": "^0.1.49", "dotenv": "^16.4.5", "esbuild": "^0.20.2", "gify-parse": "^1.0.7", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a98c6037..caa0b650 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -29,6 +29,9 @@ importers: '@material-tailwind/react': specifier: ^2.1.9 version: 2.1.9(react-dom@18.2.0(react@18.2.0))(react@18.2.0) + '@mlc-ai/web-llm': + specifier: ^0.2.73 + version: 0.2.73 '@plasmohq/messaging': specifier: ^0.6.2 version: 0.6.2(react@18.2.0) @@ -135,6 +138,9 @@ importers: '@types/semver': specifier: ^7.5.8 version: 7.5.8 + '@webgpu/types': + specifier: ^0.1.49 + version: 0.1.49 dotenv: specifier: ^16.4.5 version: 16.4.5 @@ -856,6 +862,9 @@ packages: resolution: {integrity: sha512-iA7+tyVqfrATAIsIRWQG+a7ZLLD0VaOCKV2Wd/v4mqIU3J9c4jx9p7S0nw1XH3gJCKNBOOwACOPYYSUu9pgT+w==} engines: {node: '>=12.0.0'} + '@mlc-ai/web-llm@0.2.73': + resolution: {integrity: sha512-v9eS05aptZYNBhRZ7KEq5dr1kMESHvkdwru9tNY1oS1GDDEBjUt4Y/u06ejvP3qF4VatCakEezzJT3SkuMhDnQ==} + '@motionone/animation@10.17.0': resolution: {integrity: sha512-ANfIN9+iq1kGgsZxs+Nz96uiNcPLGTXwfNo2Xz/fcJXniPYpaz/Uyrfa+7I5BPLxCP82sh7quVDudf1GABqHbg==} @@ -1946,6 +1955,9 @@ packages: '@vue/shared@3.3.4': resolution: {integrity: sha512-7OjdcV8vQ74eiz1TZLzZP4JwqM5fA94K6yntPS5Z25r9HDuGNzaGdgvwKYq6S+MxwF0TFRwe50fIR/MYnakdkQ==} + '@webgpu/types@0.1.49': + resolution: {integrity: sha512-NMmS8/DofhH/IFeW+876XrHVWel+J/vdcFCHLDqeJgkH9x0DeiwjVd8LcBdaxdG/T7Rf8VUAYsA8X1efMzLjRQ==} + abortcontroller-polyfill@1.7.5: resolution: {integrity: sha512-JMJ5soJWP18htbbxJjG7bG6yuI6pRhgJ0scHHTfkUjf6wjP912xZWvM+A4sJK3gqd9E8fcPbDnOefbA9Th/FIQ==} @@ -3043,6 +3055,10 @@ packages: resolution: {integrity: sha512-8XPvpAA8uyhfteu8pIvQxpJZ7SYYdpUivZpGy6sFsBuKRY/7rQGavedeB8aK+Zkyq6upMFVL/9AW6vOYzfRyLg==} engines: {node: '>=10'} + loglevel@1.9.2: + resolution: {integrity: sha512-HgMmCqIJSAKqo68l0rS2AanEWfkxaZ5wNiEFb5ggm08lDs9Xl2KxBlX3PTcaD2chBM1gXAYf491/M2Rv8Jwayg==} + engines: {node: '>= 0.6.0'} + loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -4673,6 +4689,10 @@ snapshots: '@lezer/lr': 1.4.0 json5: 2.2.3 + '@mlc-ai/web-llm@0.2.73': + dependencies: + loglevel: 1.9.2 + '@motionone/animation@10.17.0': dependencies: '@motionone/easing': 10.17.0 @@ -6145,6 +6165,8 @@ snapshots: '@vue/shared@3.3.4': {} + '@webgpu/types@0.1.49': {} + abortcontroller-polyfill@1.7.5: {} acorn@8.12.1: {} @@ -7203,6 +7225,8 @@ snapshots: chalk: 4.1.2 is-unicode-supported: 0.1.0 + loglevel@1.9.2: {} + loose-envify@1.4.0: dependencies: js-tokens: 4.0.0 diff --git a/src/contents/index/mounter.tsx b/src/contents/index/mounter.tsx index 5f6bbf59..fd097533 100644 --- a/src/contents/index/mounter.tsx +++ b/src/contents/index/mounter.tsx @@ -173,7 +173,7 @@ function createApp(roomId: string, plasmo: PlasmoSpec, info: StreamInfo): App { onClick: () => this.start() }, position: 'top-left', - duration: 6000000, + duration: Infinity, dismissible: false }) return diff --git a/src/hooks/life-cycle.ts b/src/hooks/life-cycle.ts index 79d8c242..13fabace 100644 --- a/src/hooks/life-cycle.ts +++ b/src/hooks/life-cycle.ts @@ -74,4 +74,34 @@ export function useTimeoutElement(before: JSX.Element, after: JSX.Element, timeo }, [after, before, timeout]); return element; +} + +/** + * A React hook that ensures the provided disposable resource is properly disposed of + * when the component is unmounted. + * + * @param disposable - The resource to be disposed of, which can be either a `Disposable` + * or an `AsyncDisposable`. The resource must implement either the `Symbol.dispose` or + * `Symbol.asyncDispose` method. + * + * @example + * ```typescript + * const MyComponent = () => { + * const disposableResource = useMemo(() => createDisposableResource(), []); + * withDisposable(disposableResource); + * + * return
My Component
; + * }; + * ``` + */ +export function useDisposable(disposable: Disposable | AsyncDisposable) { + useEffect(() => { + return () => { + if (!!disposable[Symbol.dispose]) { + disposable[Symbol.dispose](); + } else if (!!disposable[Symbol.asyncDispose]) { + disposable[Symbol.asyncDispose](); + } + } + }, []) } \ No newline at end of file diff --git a/src/llms/cloudflare-ai.ts b/src/llms/cloudflare-ai.ts index 260c855c..44acd030 100644 --- a/src/llms/cloudflare-ai.ts +++ b/src/llms/cloudflare-ai.ts @@ -41,7 +41,7 @@ export default class CloudFlareAI implements LLMProviders { console.warn('Cloudflare AI session is not supported') return { ...this, - [Symbol.dispose]: () => { } + [Symbol.asyncDispose]: async () => { } } } diff --git a/src/llms/gemini-nano.ts b/src/llms/gemini-nano.ts index f2da1f4e..d9dd23a6 100644 --- a/src/llms/gemini-nano.ts +++ b/src/llms/gemini-nano.ts @@ -19,7 +19,7 @@ export default class GeminiNano implements LLMProviders { return session.prompt(chat) } finally { console.debug('[gemini nano] done') - session[Symbol.dispose]() + await session[Symbol.asyncDispose]() } } @@ -33,7 +33,7 @@ export default class GeminiNano implements LLMProviders { } } finally { console.debug('[gemini nano] done') - session[Symbol.dispose]() + await session[Symbol.asyncDispose]() } } @@ -82,7 +82,7 @@ class GeminiAssistant implements Session { } } - [Symbol.dispose](): void { + async [Symbol.asyncDispose]() { this.assistant.destroy() } } @@ -105,7 +105,7 @@ class GeminiSummarizer implements Session { } } - [Symbol.dispose](): void { + async [Symbol.asyncDispose]() { this.summarizer.destroy() } diff --git a/src/llms/index.ts b/src/llms/index.ts index 15031b7a..c57b1212 100644 --- a/src/llms/index.ts +++ b/src/llms/index.ts @@ -3,21 +3,23 @@ import type { SettingSchema as LLMSchema } from '~options/fragments/llm' import cloudflare from './cloudflare-ai' import nano from './gemini-nano' import worker from './remote-worker' +import webllm from './web-llm' export interface LLMProviders { cumulative: boolean - validate(): Promise + validate(progress?: (p: number, t: string) => void): Promise prompt(chat: string): Promise promptStream(chat: string): AsyncGenerator asSession(): Promise> } -export type Session = Disposable & Omit +export type Session = AsyncDisposable & Omit const llms = { cloudflare, nano, - worker + worker, + webllm } export type LLMs = typeof llms diff --git a/src/llms/models.ts b/src/llms/models.ts new file mode 100644 index 00000000..23f2b9f1 --- /dev/null +++ b/src/llms/models.ts @@ -0,0 +1,33 @@ +import type { LLMTypes } from "~llms" + +export type ModelList = { + providers: LLMTypes[] + models: string[] +} + +const models: ModelList[] = [ + { + providers: ['worker', 'cloudflare'], + models: [ + '@cf/qwen/qwen1.5-14b-chat-awq', + '@cf/qwen/qwen1.5-7b-chat-awq', + '@cf/qwen/qwen1.5-1.8b-chat', + '@hf/google/gemma-7b-it', + '@hf/nousresearch/hermes-2-pro-mistral-7b' + ] + }, + { + providers: [ 'webllm' ], + models: [ + 'Qwen2-7B-Instruct-q4f32_1-MLC', + 'Qwen2.5-14B-Instruct-q4f16_1-MLC', + 'gemma-2-9b-it-q4f16_1-MLC', + 'Qwen2.5-3B-Instruct-q0f16-MLC', + 'Phi-3-mini-128k-instruct-q0f16-MLC', + 'Phi-3.5-mini-instruct-q4f16_1-MLC-1k' + ] + } +] + + +export default models \ No newline at end of file diff --git a/src/llms/remote-worker.ts b/src/llms/remote-worker.ts index 2866014e..6c859cfe 100644 --- a/src/llms/remote-worker.ts +++ b/src/llms/remote-worker.ts @@ -54,7 +54,7 @@ export default class RemoteWorker implements LLMProviders { console.warn('Remote worker session is not supported') return { ...this, - [Symbol.dispose]: () => { } + [Symbol.asyncDispose]: async () => { } } } diff --git a/src/llms/web-llm.ts b/src/llms/web-llm.ts new file mode 100644 index 00000000..6d643859 --- /dev/null +++ b/src/llms/web-llm.ts @@ -0,0 +1,87 @@ +import type { MLCEngine } from "@mlc-ai/web-llm"; +import type { LLMProviders, Session } from "~llms"; +import type { SettingSchema } from "~options/fragments/llm"; + +export default class WebLLM implements LLMProviders { + + private static readonly DEFAULT_MODEL: string = 'Qwen2-7B-Instruct-q4f32_1-MLC' + + private readonly model: string + + constructor(settings: SettingSchema) { + this.model = settings.model || WebLLM.DEFAULT_MODEL + } + + cumulative: boolean = true + + async validate(progresser?: (p: number, t: string) => void): Promise { + await this.initializeEngine(progresser) + } + + async prompt(chat: string): Promise { + const session = await this.asSession() + try { + console.debug('[web-llm] prompting: ', chat) + return session.prompt(chat) + } finally { + console.debug('[web-llm] done') + await session[Symbol.asyncDispose]() + } + } + + async *promptStream(chat: string): AsyncGenerator { + const session = await this.asSession() + try { + console.debug('[web-llm] prompting stream: ', chat) + const res = session.promptStream(chat) + for await (const chunk of res) { + yield chunk + } + } finally { + console.debug('[web-llm] done') + await session[Symbol.asyncDispose]() + } + } + + async asSession(): Promise> { + const engine = await this.initializeEngine() + return { + async prompt(chat: string) { + await engine.interruptGenerate() + const c = await engine.completions.create({ + prompt: chat, + max_tokens: 512, + temperature: 0.2, + }) + return c.choices[0]?.text ?? engine.getMessage() + }, + async *promptStream(chat: string): AsyncGenerator { + await engine.interruptGenerate() + const chunks = await engine.completions.create({ + prompt: chat, + max_tokens: 512, + temperature: 0.2, + stream: true + }) + for await (const chunk of chunks) { + yield chunk.choices[0]?.text || ""; + if (chunk.usage) { + console.debug('Usage:', chunk.usage) + } + } + }, + [Symbol.asyncDispose]: engine.unload + } + } + + private async initializeEngine(progresser?: (p: number, t: string) => void): Promise { + const { CreateMLCEngine } = await import('@mlc-ai/web-llm') + return CreateMLCEngine(this.model, { + initProgressCallback: (progress) => { + progresser?.(progress.progress, "正在下载AI模型到本地") + console.log('初始化进度:', progress) + } + }) + } + +} \ No newline at end of file diff --git a/src/options/components/Selector.tsx b/src/options/components/Selector.tsx index cc94049a..f7f5b204 100644 --- a/src/options/components/Selector.tsx +++ b/src/options/components/Selector.tsx @@ -15,6 +15,7 @@ export type SelectorProps = { disabled?: boolean options: SelectorOption[] className?: string + emptyValue?: string } @@ -37,7 +38,7 @@ function Selector(props: SelectorProps): JSX.Element {
!props.disabled && setOpen(!isOpen)}>
- {props.options.find((option) => option.value === props.value)?.label ?? String(props.value)} + {props.options.find((option) => option.value === props.value)?.label || String(props.value || (props.emptyValue || '请选择'))} diff --git a/src/options/fragments/llm.tsx b/src/options/fragments/llm.tsx index 561f2a1d..8c5ed42a 100644 --- a/src/options/fragments/llm.tsx +++ b/src/options/fragments/llm.tsx @@ -1,9 +1,10 @@ import { Button, Input, Tooltip, Typography } from "@material-tailwind/react" -import { Fragment, useState, type ChangeEvent, type ReactNode } from "react" +import { Fragment, useEffect, useMemo, useRef, useState, type ChangeEvent, type ReactNode } from "react" import { toast } from "sonner/dist" import type { StateProxy } from "~hooks/binding" import type { LLMTypes } from "~llms" import createLLMProvider from "~llms" +import models from "~llms/models" import Selector from "~options/components/Selector" export type SettingSchema = { @@ -55,21 +56,50 @@ function Hints({ children }: { children: ReactNode }): JSX.Element { function LLMSettings({ state, useHandler }: StateProxy): JSX.Element { const [validating, setValidating] = useState(false) + const toastValidating = useRef(null) const handler = useHandler, string>((e) => e.target.value) + const selectableModels = useMemo( + () => models + .filter(({ providers }) => providers.includes(state.provider)) + .flatMap(({ models }) => models) + .map(model => ({ label: model, value: model })), + [state.provider] + ) + + const onSwitchProvider = (provider: LLMTypes) => { + state.provider = provider + state.model = undefined // reset model + if (provider === 'webllm') { + toast.info('使用 WEBLLM 时,请确保你的电脑拥有足够的算力以供 AI 运行。', { position: 'top-center' }) + } + } + const onValidate = async () => { setValidating(true) - try { - const provider = createLLMProvider(state) - await provider.validate() - toast.success('配置可用!') - } catch (e) { - toast.error('配置不可用: ' + e.message) - } finally { - setValidating(false) - } + const provider = createLLMProvider(state) + const validation = provider.validate((p, t) => { + if (toastValidating.current) { + toast.loading(`${t}... (${Math.round(p * 100)}%)`, { + id: toastValidating.current + }) + } + }) + toast.dismiss() + toastValidating.current = toast.promise(validation, { + loading: `正在验证配置...`, + success: '配置可用!', + error: err => '配置不可用: ' + (err.message ?? err), + position: 'bottom-center', + duration: Infinity, + finally: () => setValidating(false) + }) + } + console.log('provider: ', state.provider) + console.log('model: ', state.model) + return ( @@ -77,11 +107,12 @@ function LLMSettings({ state, useHandler }: StateProxy): JSX.Elem data-testid="ai-provider" label="技术提供" value={state.provider} - onChange={e => state.provider = e} + onChange={onSwitchProvider} options={[ - { label: 'Cloudflare AI', value: 'cloudflare' }, - { label: '公共服务器', value: 'worker' }, - { label: 'Chrome 浏览器内置 AI', value: 'nano' } + { label: 'Cloudflare AI (云)', value: 'cloudflare' }, + { label: '公共服务器 (云)', value: 'worker' }, + { label: 'Chrome 浏览器内置 AI (本地)', value: 'nano' }, + { label: 'Web LLM (本地)', value: 'webllm' } ]} /> {state.provider === 'cloudflare' && ( @@ -110,27 +141,22 @@ function LLMSettings({ state, useHandler }: StateProxy): JSX.Elem /> )} - {['cloudflare', 'worker'].includes(state.provider) && ( + {state.provider === 'nano' && ( + + 点击此处 + 查看如何启用 Chrome 浏览器内置 AI + + )} + {selectableModels.length > 0 && ( data-testid="ai-model" label="模型提供" value={state.model} onChange={e => state.model = e} - options={[ - { label: '@cf/qwen/qwen1.5-14b-chat-awq', value: '@cf/qwen/qwen1.5-14b-chat-awq' }, - { label: '@cf/qwen/qwen1.5-7b-chat-awq', value: '@cf/qwen/qwen1.5-7b-chat-awq' }, - { label: '@cf/qwen/qwen1.5-1.8b-chat', value: '@cf/qwen/qwen1.5-1.8b-chat' }, - { label: '@hf/google/gemma-7b-it', value: '@hf/google/gemma-7b-it' }, - { label: '@hf/nousresearch/hermes-2-pro-mistral-7b', value: '@hf/nousresearch/hermes-2-pro-mistral-7b' } - ]} + options={selectableModels} + emptyValue="默认" /> )} - {state.provider === 'nano' && ( - - 点击此处 - 查看如何启用 Chrome 浏览器内置 AI - - )}