From df5fbe3a734d1b68fa6ba89339975b432cf37027 Mon Sep 17 00:00:00 2001 From: Jackson Chen <541898146chen@gmail.com> Date: Fri, 8 Nov 2024 04:31:55 -0600 Subject: [PATCH] feat: Add ProjectBuilderModule and ProjectBuilderService --- backend/src/chat/chat.resolver.ts | 2 +- backend/src/chat/chat.service.ts | 207 ++-------------- backend/src/chat/dto/chat.input.ts | 1 + backend/src/common/model-provider/index.ts | 267 +++++++++++++++++++++ backend/src/common/model-provider/types.ts | 7 + 5 files changed, 291 insertions(+), 193 deletions(-) create mode 100644 backend/src/common/model-provider/index.ts create mode 100644 backend/src/common/model-provider/types.ts diff --git a/backend/src/chat/chat.resolver.ts b/backend/src/chat/chat.resolver.ts index 431c9916..57fc8bfd 100644 --- a/backend/src/chat/chat.resolver.ts +++ b/backend/src/chat/chat.resolver.ts @@ -53,7 +53,7 @@ export class ChatResolver { MessageRole.User, ); - const iterator = this.chatProxyService.streamChat(input.message); + const iterator = this.chatProxyService.streamChat(input); let accumulatedContent = ''; for await (const chunk of iterator) { diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts index 0e5579b1..8f5bc492 100644 --- a/backend/src/chat/chat.service.ts +++ b/backend/src/chat/chat.service.ts @@ -5,208 +5,31 @@ import { Message, MessageRole } from 'src/chat/message.model'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; import { User } from 'src/user/user.model'; -import { NewChatInput, UpdateChatTitleInput } from 'src/chat/dto/chat.input'; - -type CustomAsyncIterableIterator = AsyncIterator & { - [Symbol.asyncIterator](): AsyncIterableIterator; -}; +import { + ChatInput, + NewChatInput, + UpdateChatTitleInput, +} from 'src/chat/dto/chat.input'; +import { CustomAsyncIterableIterator } from 'src/common/model-provider/types'; +import { ModelProvider } from 'src/common/model-provider'; @Injectable() export class ChatProxyService { private readonly logger = new Logger('ChatProxyService'); + private models: ModelProvider; - constructor(private httpService: HttpService) {} - - streamChat(input: string): CustomAsyncIterableIterator { - this.logger.debug('request chat input: ' + input); - let isDone = false; - let responseSubscription: any; - const chunkQueue: ChatCompletionChunk[] = []; - let resolveNextChunk: - | ((value: IteratorResult) => void) - | null = null; - - const iterator: CustomAsyncIterableIterator = { - next: () => { - return new Promise>((resolve) => { - if (chunkQueue.length > 0) { - resolve({ done: false, value: chunkQueue.shift()! }); - } else if (isDone) { - resolve({ done: true, value: undefined }); - } else { - resolveNextChunk = resolve; - } - }); - }, - return: () => { - isDone = true; - if (responseSubscription) { - responseSubscription.unsubscribe(); - } - return Promise.resolve({ done: true, value: undefined }); - }, - throw: (error) => { - isDone = true; - if (responseSubscription) { - responseSubscription.unsubscribe(); - } - return Promise.reject(error); - }, - [Symbol.asyncIterator]() { - return this; - }, - }; - - responseSubscription = this.httpService - .post( - 'http://localhost:3001/chat/completion', - { content: input }, - { responseType: 'stream' }, - ) - .subscribe({ - next: (response) => { - let buffer = ''; - response.data.on('data', (chunk: Buffer) => { - buffer += chunk.toString(); - let newlineIndex; - while ((newlineIndex = buffer.indexOf('\n')) !== -1) { - const line = buffer.slice(0, newlineIndex).trim(); - buffer = buffer.slice(newlineIndex + 1); - - if (line.startsWith('data: ')) { - const jsonStr = line.slice(6); - // TODO: don't remove rn - if (jsonStr === '[DONE]') { - return; - } - // if (jsonStr === '[DONE]') { - // const doneChunk: ChatCompletionChunk = { - // id: 'done', - // object: 'chat.completion.chunk', - // created: Date.now(), - // model: '', - // systemFingerprint: null, - // choices: [], - // status: StreamStatus.DONE, - // }; - - // if (resolveNextChunk) { - // resolveNextChunk({ done: false, value: doneChunk }); - // resolveNextChunk = null; - // } else { - // chunkQueue.push(doneChunk); - // } - // return; - // } - try { - const parsed = JSON.parse(jsonStr); - if (this.isValidChunk(parsed)) { - const parsedChunk: ChatCompletionChunk = { - ...parsed, - status: StreamStatus.STREAMING, - }; - - if (resolveNextChunk) { - resolveNextChunk({ done: false, value: parsedChunk }); - resolveNextChunk = null; - } else { - chunkQueue.push(parsedChunk); - } - } else { - this.logger.warn('Invalid chunk received:', parsed); - } - } catch (error) { - this.logger.error('Error parsing chunk:', error); - } - } - } - }); - response.data.on('end', () => { - this.logger.debug('Stream ended'); - if (!isDone) { - const doneChunk: ChatCompletionChunk = { - id: 'done', - object: 'chat.completion.chunk', - created: Date.now(), - model: 'gpt-3.5-turbo', - systemFingerprint: null, - choices: [], - status: StreamStatus.DONE, - }; - - if (resolveNextChunk) { - resolveNextChunk({ done: false, value: doneChunk }); - resolveNextChunk = null; - } else { - chunkQueue.push(doneChunk); - } - } - - setTimeout(() => { - isDone = true; - if (resolveNextChunk) { - resolveNextChunk({ done: true, value: undefined }); - resolveNextChunk = null; - } - }, 0); - }); - }, - error: (error) => { - this.logger.error('Error in stream:', error); - const doneChunk: ChatCompletionChunk = { - id: 'done', - object: 'chat.completion.chunk', - created: Date.now(), - model: 'gpt-3.5-turbo', - systemFingerprint: null, - choices: [], - status: StreamStatus.DONE, - }; - - if (resolveNextChunk) { - resolveNextChunk({ done: false, value: doneChunk }); - setTimeout(() => { - isDone = true; - resolveNextChunk?.({ done: true, value: undefined }); - resolveNextChunk = null; - }, 0); - } else { - chunkQueue.push(doneChunk); - setTimeout(() => { - isDone = true; - }, 0); - } - }, - }); - - return iterator; + constructor(private httpService: HttpService) { + this.models = ModelProvider.getInstance(); } - private isValidChunk(chunk: any): boolean { - return ( - chunk && - typeof chunk.id === 'string' && - typeof chunk.object === 'string' && - typeof chunk.created === 'number' && - typeof chunk.model === 'string' - ); + streamChat( + input: ChatInput, + ): CustomAsyncIterableIterator { + return this.models.chat(input.message, input.model, input.chatId); } async fetchModelTags(): Promise { - try { - this.logger.debug('Requesting model tags from /tags endpoint.'); - - // Make a GET request to /tags - const response = await this.httpService - .get('http://localhost:3001/tags', { responseType: 'json' }) - .toPromise(); - - this.logger.debug('Model tags received:', response.data); - return response.data; - } catch (error) { - this.logger.error('Error fetching model tags:', error); - throw new Error('Failed to fetch model tags'); - } + return this.models.fetchModelsName(); } } diff --git a/backend/src/chat/dto/chat.input.ts b/backend/src/chat/dto/chat.input.ts index 2a3bde2a..feeb738c 100644 --- a/backend/src/chat/dto/chat.input.ts +++ b/backend/src/chat/dto/chat.input.ts @@ -16,6 +16,7 @@ export class UpdateChatTitleInput { title: string; } +// TODO: using ChatInput in model-provider.ts @InputType('ChatInputType') export class ChatInput { @Field() diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts new file mode 100644 index 00000000..12f133a8 --- /dev/null +++ b/backend/src/common/model-provider/index.ts @@ -0,0 +1,267 @@ +import { Logger } from '@nestjs/common'; +import { HttpService } from '@nestjs/axios'; +import { ChatCompletionChunk, StreamStatus } from 'src/chat/chat.model'; + +export interface ChatInput { + content: string; + attachments?: Array<{ + type: string; + content: string | Buffer; + name?: string; + }>; + contextLength?: number; + temperature?: number; +} + +export interface ModelProviderConfig { + endpoint: string; + defaultModel?: string; +} + +export interface CustomAsyncIterableIterator extends AsyncIterator { + [Symbol.asyncIterator](): AsyncIterableIterator; +} + +export class ModelProvider { + private readonly logger = new Logger('ModelProvider'); + private isDone = false; + private responseSubscription: any; + private chunkQueue: ChatCompletionChunk[] = []; + private resolveNextChunk: + | ((value: IteratorResult) => void) + | null = null; + + private static instance: ModelProvider | undefined = undefined; + + public static getInstance() { + if (this.instance) { + return this.instance; + } + + return new ModelProvider(new HttpService(), { + // TODO: adding into env + endpoint: 'http://localhost:3001', + }); + } + + constructor( + private readonly httpService: HttpService, + private readonly config: ModelProviderConfig, + ) {} + + chat( + input: ChatInput | string, + model?: string, + chatId?: string, + ): CustomAsyncIterableIterator { + const chatInput = this.normalizeChatInput(input); + const selectedModel = model || this.config.defaultModel || undefined; + if (selectedModel === undefined) { + this.logger.error('No model selected for chat request'); + return; + } + + this.logger.debug( + `Chat request - Model: ${selectedModel}, ChatId: ${chatId || 'N/A'}`, + { input: chatInput }, + ); + + const iterator: CustomAsyncIterableIterator = { + next: () => this.handleNext(), + return: () => this.handleReturn(), + throw: (error) => this.handleThrow(error), + [Symbol.asyncIterator]() { + return this; + }, + }; + + this.startChat(chatInput, selectedModel, chatId); + return iterator; + } + + private normalizeChatInput(input: ChatInput | string): ChatInput { + if (typeof input === 'string') { + return { content: input }; + } + return input; + } + + private handleNext(): Promise> { + return new Promise>((resolve) => { + if (this.chunkQueue.length > 0) { + resolve({ done: false, value: this.chunkQueue.shift()! }); + } else if (this.isDone) { + resolve({ done: true, value: undefined }); + } else { + this.resolveNextChunk = resolve; + } + }); + } + + private handleReturn(): Promise> { + this.cleanup(); + return Promise.resolve({ done: true, value: undefined }); + } + + private handleThrow( + error: any, + ): Promise> { + this.cleanup(); + return Promise.reject(error); + } + + private cleanup() { + this.isDone = true; + if (this.responseSubscription) { + this.responseSubscription.unsubscribe(); + } + } + + private createRequestPayload( + input: ChatInput, + model: string, + chatId?: string, + ) { + return { + ...input, + model, + ...(chatId && { chatId }), + }; + } + + private createDoneChunk(model: string): ChatCompletionChunk { + return { + id: 'done', + object: 'chat.completion.chunk', + created: Date.now(), + model, + systemFingerprint: null, + choices: [], + status: StreamStatus.DONE, + }; + } + + private handleChunk(chunk: any) { + if (this.isValidChunk(chunk)) { + const parsedChunk: ChatCompletionChunk = { + ...chunk, + status: StreamStatus.STREAMING, + }; + + if (this.resolveNextChunk) { + this.resolveNextChunk({ done: false, value: parsedChunk }); + this.resolveNextChunk = null; + } else { + this.chunkQueue.push(parsedChunk); + } + } else { + this.logger.warn('Invalid chunk received:', chunk); + } + } + + private handleStreamEnd(model: string) { + this.logger.debug('Stream ended'); + if (!this.isDone) { + const doneChunk = this.createDoneChunk(model); + if (this.resolveNextChunk) { + this.resolveNextChunk({ done: false, value: doneChunk }); + this.resolveNextChunk = null; + } else { + this.chunkQueue.push(doneChunk); + } + } + + setTimeout(() => { + this.isDone = true; + if (this.resolveNextChunk) { + this.resolveNextChunk({ done: true, value: undefined }); + this.resolveNextChunk = null; + } + }, 0); + } + + private handleStreamError(error: any, model: string) { + this.logger.error('Error in stream:', error); + const doneChunk = this.createDoneChunk(model); + + if (this.resolveNextChunk) { + this.resolveNextChunk({ done: false, value: doneChunk }); + setTimeout(() => { + this.isDone = true; + this.resolveNextChunk?.({ done: true, value: undefined }); + this.resolveNextChunk = null; + }, 0); + } else { + this.chunkQueue.push(doneChunk); + setTimeout(() => { + this.isDone = true; + }, 0); + } + } + + private startChat(input: ChatInput, model: string, chatId?: string) { + const payload = this.createRequestPayload(input, model, chatId); + + this.responseSubscription = this.httpService + .post(`${this.config.endpoint}/chat/completion`, payload, { + responseType: 'stream', + headers: { + 'Content-Type': 'application/json', + }, + }) + .subscribe({ + next: (response) => { + let buffer = ''; + response.data.on('data', (chunk: Buffer) => { + buffer += chunk.toString(); + let newlineIndex; + while ((newlineIndex = buffer.indexOf('\n')) !== -1) { + const line = buffer.slice(0, newlineIndex).trim(); + buffer = buffer.slice(newlineIndex + 1); + + if (line.startsWith('data: ')) { + const jsonStr = line.slice(6); + if (jsonStr === '[DONE]') { + return; + } + try { + const parsed = JSON.parse(jsonStr); + this.handleChunk(parsed); + } catch (error) { + this.logger.error('Error parsing chunk:', error); + } + } + } + }); + response.data.on('end', () => this.handleStreamEnd(model)); + }, + error: (error) => this.handleStreamError(error, model), + }); + } + + private isValidChunk(chunk: any): boolean { + return ( + chunk && + typeof chunk.id === 'string' && + typeof chunk.object === 'string' && + typeof chunk.created === 'number' && + typeof chunk.model === 'string' + ); + } + + public async fetchModelsName() { + try { + this.logger.debug('Requesting model tags from /tags endpoint.'); + + // Make a GET request to /tags + const response = await this.httpService + .get(`${this.config.endpoint}/tags`, { responseType: 'json' }) + .toPromise(); + this.logger.debug('Model tags received:', response.data); + return response.data; + } catch (error) { + this.logger.error('Error fetching model tags:', error); + throw new Error('Failed to fetch model tags'); + } + } +} diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts new file mode 100644 index 00000000..8c649851 --- /dev/null +++ b/backend/src/common/model-provider/types.ts @@ -0,0 +1,7 @@ +export interface ModelChatStreamConfig { + endpoint: string; + model?: string; +} +export type CustomAsyncIterableIterator = AsyncIterator & { + [Symbol.asyncIterator](): AsyncIterableIterator; +};