Skip to content

Commit

Permalink
reshaped ai schema and model selectable
Browse files Browse the repository at this point in the history
  • Loading branch information
eric2788 committed Oct 22, 2024
1 parent 3f2b241 commit 4d5e611
Show file tree
Hide file tree
Showing 11 changed files with 204 additions and 149 deletions.
16 changes: 11 additions & 5 deletions src/api/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export async function runAI(data: any, { token, account, model }: { token: strin
return json
}

export async function *runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator<string> {
export async function* runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator<string> {
const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, {
method: 'POST',
headers: {
Expand All @@ -35,12 +35,18 @@ export async function *runAIStream(data: any, { token, account, model }: { token
}
}

export async function validateAIToken(accountId: string, token: string): Promise<boolean> {
const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?per_page=1`, {
export async function validateAIToken(accountId: string, token: string, model: string): Promise<string | boolean> {
const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?search=${model}&per_page=1`, {
headers: {
Authorization: `Bearer ${this.apiToken}`
Authorization: `Bearer ${token}`
}
})
const data = await res.json() as Result<any>
return data.success
if (!data.success) {
return false
} else if (data.result.length === 0) {
return '找不到指定 AI 模型'
} else {
return true
}
}
2 changes: 1 addition & 1 deletion src/features/jimaku/components/ButtonArea.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element {
弹出同传视窗
</JimakuButton>
}
{aiZone.enabled && (
{aiZone.summarizeEnabled && (
<JimakuButton onClick={summerize}>
同传字幕AI总结
</JimakuButton>
Expand Down
32 changes: 22 additions & 10 deletions src/llms/cf-qwen.ts → src/llms/cloudflare-ai.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,38 @@
import { runAI, runAIStream, validateAIToken } from "~api/cloudflare";
import type { LLMProviders, Session } from "~llms";
import type { SettingSchema } from "~options/fragments/llm";

export default class CloudFlareQwen implements LLMProviders {
export default class CloudFlareAI implements LLMProviders {

private static readonly MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq'
private static readonly DEFAULT_MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq'

constructor(
private readonly accountId: string,
private readonly apiToken: string,
) { }
private readonly accountId: string
private readonly apiToken: string

private readonly model: string

constructor(settings: SettingSchema) {
this.accountId = settings.accountId
this.apiToken = settings.apiToken

// only text generation model for now
this.model = settings.model || CloudFlareAI.DEFAULT_MODEL
}

async validate(): Promise<void> {
const success = await validateAIToken(this.accountId, this.apiToken)
if (!success) throw new Error('Cloudflare API 验证失败')
const success = await validateAIToken(this.accountId, this.apiToken, this.model)
if (typeof success === 'boolean' && !success) throw new Error('Cloudflare API 验证失败')
if (typeof success === 'string') throw new Error(success)
}

async prompt(chat: string): Promise<string> {
const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL })
const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model })
if (!res.result) throw new Error(res.errors.join(', '))
return res.result.response
}

async *promptStream(chat: string): AsyncGenerator<string> {
return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL })
return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model })
}

async asSession(): Promise<Session<LLMProviders>> {
Expand All @@ -33,6 +43,8 @@ export default class CloudFlareQwen implements LLMProviders {
}
}

// text generation model input schema
// so only text generation model for now
private wrap(chat: string): any {
return {
max_tokens: 512,
Expand Down
13 changes: 8 additions & 5 deletions src/llms/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import qwen from './cf-qwen'
import type { SettingSchema as LLMSchema } from '~options/fragments/llm'

import cloudflare from './cloudflare-ai'
import nano from './gemini-nano'
import worker from './remote-worker'

Expand All @@ -12,7 +14,7 @@ export interface LLMProviders {
export type Session<T> = Disposable & Omit<T, 'asSession' | 'validate'>

const llms = {
qwen,
cloudflare,
nano,
worker
}
Expand All @@ -21,9 +23,10 @@ export type LLMs = typeof llms

export type LLMTypes = keyof LLMs

function createLLMProvider<K extends LLMTypes, M extends LLMs[K]>(type: K, ...args: ConstructorParameters<M>): LLMProviders {
const LLM = llms[type].bind(this, ...args)
return new LLM()
function createLLMProvider(settings: LLMSchema): LLMProviders {
const type = settings.provider
const LLM = llms[type]
return new LLM(settings)
}

export default createLLMProvider
12 changes: 9 additions & 3 deletions src/llms/remote-worker.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import type { LLMProviders, Session } from "~llms";
import type { SettingSchema } from "~options/fragments/llm";
import { parseSSEResponses } from "~utils/binary";


// for my worker, so limited usage
export default class RemoteWorker implements LLMProviders {

private readonly model?: string

constructor(settings: SettingSchema) {
this.model = settings.model || undefined
}

async validate(): Promise<void> {
const res = await fetch('https://llm.ericlamm.xyz/status')
const json = await res.json()
Expand All @@ -19,7 +25,7 @@ export default class RemoteWorker implements LLMProviders {
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ prompt: chat })
body: JSON.stringify({ prompt: chat, model: this.model })
})
if (!res.ok) throw new Error(await res.text())
const json = await res.json()
Expand All @@ -32,7 +38,7 @@ export default class RemoteWorker implements LLMProviders {
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ prompt: chat, stream: true })
body: JSON.stringify({ prompt: chat, stream: true, model: this.model })
})
if (!res.ok) throw new Error(await res.text())
if (!res.body) throw new Error('Remote worker response body is not readable')
Expand Down
113 changes: 7 additions & 106 deletions src/options/features/jimaku/components/AIFragment.tsx
Original file line number Diff line number Diff line change
@@ -1,134 +1,35 @@
import { Button, Input, List, Tooltip, Typography } from "@material-tailwind/react"
import { type ChangeEvent, Fragment, useState } from "react"
import { toast } from "sonner/dist"
import { List } from "@material-tailwind/react"
import { type ChangeEvent, Fragment } from "react"
import type { StateProxy } from "~hooks/binding"
import type { LLMProviders, LLMTypes } from "~llms"
import createLLMProvider from "~llms"
import ExperienmentFeatureIcon from "~options/components/ExperientmentFeatureIcon"
import Selector from "~options/components/Selector"
import SwitchListItem from "~options/components/SwitchListItem"



export type AISchema = {
enabled: boolean
provider: LLMTypes

// cloudflare settings
accountId?: string
apiToken?: string
summarizeEnabled: boolean
}


export const aiDefaultSettings: Readonly<AISchema> = {
enabled: false,
provider: 'worker'
summarizeEnabled: false
}


function AIFragment({ state, useHandler }: StateProxy<AISchema>): JSX.Element {

const [validating, setValidating] = useState(false)

const handler = useHandler<ChangeEvent<HTMLInputElement>, string>((e) => e.target.value)
const checker = useHandler<ChangeEvent<HTMLInputElement>, boolean>((e) => e.target.checked)

const onValidate = async () => {
setValidating(true)
try {
let provider: LLMProviders;
if (state.provider === 'qwen') {
provider = createLLMProvider(state.provider, state.accountId, state.apiToken)
} else {
provider = createLLMProvider(state.provider)
}
await provider.validate()
toast.success('配置可用!')
} catch (e) {
toast.error('配置不可用: ' + e.message)
} finally {
setValidating(false)
}
}

return (
<Fragment>
<List className="col-span-2 border border-[#808080] rounded-md">
<SwitchListItem
data-testid="ai-enabled"
label="启用同传字幕AI总结"
hint="此功能将采用通义大模型对同传字幕进行总结"
value={state.enabled}
onChange={checker('enabled')}
hint="此功能将采用大语言模型对同传字幕进行总结"
value={state.summarizeEnabled}
onChange={checker('summarizeEnabled')}
marker={<ExperienmentFeatureIcon />}
/>
</List>
{state.enabled && (
<Fragment>
<Selector<typeof state.provider>
className="col-span-2"
data-testid="ai-provider"
label="技术来源"
value={state.provider}
onChange={e => state.provider = e}
options={[
{ label: 'Cloudflare AI', value: 'qwen' },
{ label: '有限度服务器', value: 'worker' },
{ label: 'Chrome 浏览器内置 AI', value: 'nano' }
]}
/>
{state.provider === 'qwen' && (
<Fragment>
<Typography
className="flex items-center gap-1 font-normal dark:text-gray-200 col-span-2"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="-mt-px h-6 w-6"
>
<path
fillRule="evenodd"
d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm8.706-1.442c1.146-.573 2.437.463 2.126 1.706l-.709 2.836.042-.02a.75.75 0 01.67 1.34l-.04.022c-1.147.573-2.438-.463-2.127-1.706l.71-2.836-.042.02a.75.75 0 11-.671-1.34l.041-.022zM12 9a.75.75 0 100-1.5.75.75 0 000 1.5z"
clipRule="evenodd"
/>
</svg>
<Typography className="underline" as="a" href="https://linux.do/t/topic/34037" target="_blank">点击此处</Typography>
查看如何获得 Cloudflare API Token 和 Account ID
</Typography>
<Input
data-testid="cf-account-id"
crossOrigin="anonymous"
variant="static"
required
label="Cloudflare Account ID"
value={state.accountId}
onChange={handler('accountId')}
/>
<Input
data-testid="cf-api-token"
crossOrigin="anonymous"
variant="static"
required
label="Cloudflare API Token"
value={state.apiToken}
onChange={handler('apiToken')}
/>
</Fragment>
)}
</Fragment>
)}
<div className="col-span-2">
<Button disabled={validating} onClick={onValidate} color="blue" size="lg" className="group flex items-center justify-center gap-3 text-[1rem] hover:shadow-lg">
验证是否可用
<Tooltip content="检查你目前的配置是否可用。若不可用,则无法启用AI总结功能。" placement="top-end">
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" strokeWidth={1.5} stroke="currentColor" className="size-6">
<path strokeLinecap="round" strokeLinejoin="round" d="m11.25 11.25.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Zm-9-3.75h.008v.008H12V8.25Z" />
</svg>
</Tooltip>
</Button>
</div>
</Fragment>
)
}
Expand Down
2 changes: 2 additions & 0 deletions src/options/fragments.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import * as display from './fragments/display'
import * as features from './fragments/features'
import * as listings from './fragments/listings'
import * as version from './fragments/version'
import * as llm from './fragments/llm'


interface SettingFragment<T extends object> {
Expand All @@ -28,6 +29,7 @@ const fragments = {
'settings.listings': listings,
'settings.capture': capture,
'settings.display': display,
'settings.llm': llm,
'settings.developer': developer,
'settings.version': version
}
Expand Down
Loading

0 comments on commit 4d5e611

Please sign in to comment.