diff --git a/backend/src/chat/chat.resolver.ts b/backend/src/chat/chat.resolver.ts
index 431c991..57fc8bf 100644
--- a/backend/src/chat/chat.resolver.ts
+++ b/backend/src/chat/chat.resolver.ts
@@ -53,7 +53,7 @@ export class ChatResolver {
       MessageRole.User,
     );
 
-    const iterator = this.chatProxyService.streamChat(input.message);
+    const iterator = this.chatProxyService.streamChat(input);
     let accumulatedContent = '';
 
     for await (const chunk of iterator) {
diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts
index 0e5579b..e1b356a 100644
--- a/backend/src/chat/chat.service.ts
+++ b/backend/src/chat/chat.service.ts
@@ -5,7 +5,11 @@ import { Message, MessageRole } from 'src/chat/message.model';
 import { InjectRepository } from '@nestjs/typeorm';
 import { Repository } from 'typeorm';
 import { User } from 'src/user/user.model';
-import { NewChatInput, UpdateChatTitleInput } from 'src/chat/dto/chat.input';
+import {
+  ChatInput,
+  NewChatInput,
+  UpdateChatTitleInput,
+} from 'src/chat/dto/chat.input';
 
 type CustomAsyncIterableIterator<T> = AsyncIterator<T> & {
   [Symbol.asyncIterator](): AsyncIterableIterator<T>;
@@ -17,8 +21,12 @@ export class ChatProxyService {
 
   constructor(private httpService: HttpService) {}
 
-  streamChat(input: string): CustomAsyncIterableIterator<ChatCompletionChunk> {
-    this.logger.debug('request chat input: ' + input);
+  streamChat(
+    input: ChatInput,
+  ): CustomAsyncIterableIterator<ChatCompletionChunk> {
+    this.logger.debug(
+      `Request chat input: ${input.message} with model: ${input.model}`,
+    );
     let isDone = false;
     let responseSubscription: any;
     const chunkQueue: ChatCompletionChunk[] = [];
@@ -60,7 +68,7 @@ export class ChatProxyService {
       responseSubscription = this.httpService
         .post(
           'http://localhost:3001/chat/completion',
-          { content: input },
+          { content: input.message, model: input.model },
          { responseType: 'stream' },
        )
        .subscribe({
diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts
index 2286a9d..f47f999 100644
--- a/llm-server/src/llm-provider.ts
+++ b/llm-server/src/llm-provider.ts
@@ -3,6 +3,7 @@ import { ModelProvider } from './model/model-provider';
 import { OpenAIModelProvider } from './model/openai-model-provider';
 import { LlamaModelProvider } from './model/llama-model-provider';
 import { Logger } from '@nestjs/common';
+import { GenerateMessageParams } from './type/GenerateMessage';
 
 export interface ChatMessageInput {
   content: string;
@@ -32,10 +33,10 @@ export class LLMProvider {
   }
 
   async generateStreamingResponse(
-    content: string,
+    params: GenerateMessageParams,
     res: Response,
   ): Promise<void> {
-    await this.modelProvider.generateStreamingResponse(content, res);
+    await this.modelProvider.generateStreamingResponse(params, res);
   }
 
   async getModelTags(res: Response): Promise<void> {
diff --git a/llm-server/src/main.ts b/llm-server/src/main.ts
index f062f05..fbc7391 100644
--- a/llm-server/src/main.ts
+++ b/llm-server/src/main.ts
@@ -1,6 +1,7 @@
 import { Logger } from '@nestjs/common';
 import { ChatMessageInput, LLMProvider } from './llm-provider';
 import express, { Express, Request, Response } from 'express';
+import { GenerateMessageParams } from './type/GenerateMessage';
 
 export class App {
   private readonly logger = new Logger(App.name);
@@ -27,13 +28,22 @@ export class App {
       this.logger.log('Received chat request.');
       try {
         this.logger.debug(JSON.stringify(req.body));
-        const { content } = req.body as ChatMessageInput;
+        const { content, model } = req.body as ChatMessageInput & {
+          model: string;
+        };
+
+        const params: GenerateMessageParams = {
+          model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
+          message: content,
+          role: 'user',
+        };
+
         this.logger.debug(`Request content: "${content}"`);
         res.setHeader('Content-Type', 'text/event-stream');
         res.setHeader('Cache-Control', 'no-cache');
         res.setHeader('Connection', 'keep-alive');
         this.logger.debug('Response headers set for streaming.');
-        await this.llmProvider.generateStreamingResponse(content, res);
+        await this.llmProvider.generateStreamingResponse(params, res);
       } catch (error) {
         this.logger.error('Error in chat endpoint:', error);
         res.status(500).json({ error: 'Internal server error' });
diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts
index 07a24b7..2b25159 100644
--- a/llm-server/src/model/llama-model-provider.ts
+++ b/llm-server/src/model/llama-model-provider.ts
@@ -8,6 +8,9 @@ import {
 } from 'node-llama-cpp';
 import { ModelProvider } from './model-provider.js';
 import { Logger } from '@nestjs/common';
+import { systemPrompts } from '../prompt/systemPrompt';
+import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
+import { GenerateMessageParams } from '../type/GenerateMessage';
 
 //TODO: using protocol class
 export class LlamaModelProvider extends ModelProvider {
@@ -33,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider {
   }
 
   async generateStreamingResponse(
-    content: string,
+    { model, message, role = 'user' }: GenerateMessageParams,
     res: Response,
   ): Promise<void> {
     this.logger.log('Generating streaming response with Llama...');
@@ -43,8 +46,22 @@ export class LlamaModelProvider extends ModelProvider {
     this.logger.log('LlamaChatSession created.');
     let chunkCount = 0;
     const startTime = Date.now();
+
+    // Get the system prompt based on the model
+    const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+
+    const messages = [
+      { role: 'system', content: systemPrompt },
+      { role: role as 'user' | 'system' | 'assistant', content: message },
+    ];
+
+    // Convert messages array to a single formatted string for Llama
+    const formattedPrompt = messages
+      .map(({ role, content }) => `${role}: ${content}`)
+      .join('\n');
+
     try {
-      await session.prompt(content, {
+      await session.prompt(formattedPrompt, {
         onTextChunk: chunk => {
           chunkCount++;
           this.logger.debug(`Sending chunk #${chunkCount}: "${chunk}"`);
diff --git a/llm-server/src/model/model-provider.ts b/llm-server/src/model/model-provider.ts
index 07f6a0b..4d82329 100644
--- a/llm-server/src/model/model-provider.ts
+++ b/llm-server/src/model/model-provider.ts
@@ -1,9 +1,10 @@
 import { Response } from 'express';
+import { GenerateMessageParams } from '../type/GenerateMessage';
 
 export abstract class ModelProvider {
   abstract initialize(): Promise<void>;
   abstract generateStreamingResponse(
-    content: string,
+    params: GenerateMessageParams,
     res: Response,
   ): Promise<void>;
 
diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts
index f48c30f..93c990c 100644
--- a/llm-server/src/model/openai-model-provider.ts
+++ b/llm-server/src/model/openai-model-provider.ts
@@ -2,6 +2,10 @@ import { Response } from 'express';
 import OpenAI from 'openai';
 import { ModelProvider } from './model-provider';
 import { Logger } from '@nestjs/common';
+import { systemPrompts } from '../prompt/systemPrompt';
+import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
+import { GenerateMessageParams } from '../type/GenerateMessage';
+
 export class OpenAIModelProvider extends ModelProvider {
   private readonly logger = new Logger(OpenAIModelProvider.name);
   private openai: OpenAI;
@@ -15,23 +19,34 @@ export class OpenAIModelProvider extends ModelProvider {
   }
 
   async generateStreamingResponse(
-    content: string,
+    { model, message, role = 'user' }: GenerateMessageParams,
     res: Response,
   ): Promise<void> {
     this.logger.log('Generating streaming response with OpenAI...');
     const startTime = Date.now();
+
     // Set SSE headers
     res.writeHead(200, {
       'Content-Type': 'text/event-stream',
       'Cache-Control': 'no-cache',
       Connection: 'keep-alive',
     });
+
+    // Get the system prompt based on the model
+    const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+
+    const messages: ChatCompletionMessageParam[] = [
+      { role: 'system', content: systemPrompt },
+      { role: role as 'user' | 'system' | 'assistant', content: message },
+    ];
+
     try {
       const stream = await this.openai.chat.completions.create({
-        model: 'gpt-3.5-turbo',
-        messages: [{ role: 'user', content: content }],
+        model,
+        messages,
         stream: true,
       });
+
       let chunkCount = 0;
       for await (const chunk of stream) {
         const content = chunk.choices[0]?.delta?.content || '';
@@ -41,6 +56,7 @@ export class OpenAIModelProvider extends ModelProvider {
           res.write(`data: ${JSON.stringify(chunk)}\n\n`);
         }
       }
+
       const endTime = Date.now();
       this.logger.log(
         `Response generation completed. Total chunks: ${chunkCount}`,
@@ -59,20 +75,18 @@ export class OpenAIModelProvider extends ModelProvider {
 
   async getModelTagsResponse(res: Response): Promise<void> {
     this.logger.log('Fetching available models from OpenAI...');
-    // Set SSE headers
     res.writeHead(200, {
       'Content-Type': 'text/event-stream',
       'Cache-Control': 'no-cache',
       Connection: 'keep-alive',
     });
+
     try {
       const startTime = Date.now();
       const models = await this.openai.models.list();
-
       const response = {
-        models: models, // Wrap the models in the required structure
+        models: models,
       };
-
       const endTime = Date.now();
       this.logger.log(
         `Model fetching completed. Total models: ${models.data.length}`,
diff --git a/llm-server/src/prompt/systemPrompt.ts b/llm-server/src/prompt/systemPrompt.ts
new file mode 100644
index 0000000..f1edd0e
--- /dev/null
+++ b/llm-server/src/prompt/systemPrompt.ts
@@ -0,0 +1,7 @@
+// Define and export the system prompts object
+export const systemPrompts = {
+  'codefox-basic': {
+    systemPrompt: `You are CodeFox, an advanced and powerful AI specialized in code generation and software engineering.
+    Your purpose is to help developers build complete and efficient applications by providing well-structured, optimized, and maintainable code.`,
+  },
+};
diff --git a/llm-server/src/type/GenerateMessage.ts b/llm-server/src/type/GenerateMessage.ts
new file mode 100644
index 0000000..c7d8f6d
--- /dev/null
+++ b/llm-server/src/type/GenerateMessage.ts
@@ -0,0 +1,5 @@
+export interface GenerateMessageParams {
+  model: string; // Model to use, e.g., 'gpt-3.5-turbo'
+  message: string; // User's message or query
+  role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role
+}
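Note: the backend half of this change imports `ChatInput` from `src/chat/dto/chat.input`, but that DTO is not part of the diff. Going only by the fields the resolver and `ChatProxyService` read (`input.message` and `input.model`), a minimal sketch of what it presumably looks like is below; the decorators assume the NestJS GraphQL setup already used by `ChatResolver`, `model` is marked optional because llm-server falls back to `'gpt-3.5-turbo'` when it is missing, and any further fields the real DTO may carry are omitted here.

```typescript
// Hypothetical sketch of src/chat/dto/chat.input.ts -- not included in this diff.
import { Field, InputType } from '@nestjs/graphql';

@InputType()
export class ChatInput {
  // The user's message, forwarded as `content` to the llm-server.
  @Field()
  message: string;

  // Requested model id; llm-server defaults to 'gpt-3.5-turbo' when omitted.
  @Field({ nullable: true })
  model?: string;
}
```

With a shape like this, `streamChat` posts `{ content: input.message, model: input.model }` to `http://localhost:3001/chat/completion`, which is exactly what the updated handler in `llm-server/src/main.ts` destructures into `GenerateMessageParams`.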