update to use system prompt
ZHallen122 committed Nov 3, 2024
1 parent 2baa14b commit c24518f
Showing 6 changed files with 42 additions and 10 deletions.

llm-server/src/llm-provider.ts (3 additions, 2 deletions)

@@ -3,6 +3,7 @@ import { ModelProvider } from './model/model-provider';
import { OpenAIModelProvider } from './model/openai-model-provider';
import { LlamaModelProvider } from './model/llama-model-provider';
import { Logger } from '@nestjs/common';
+ import { GenerateMessageParams } from './type/GenerateMessage';

export interface ChatMessageInput {
content: string;
@@ -32,10 +33,10 @@ export class LLMProvider {
}

async generateStreamingResponse(
- content: string,
+ params: GenerateMessageParams,
res: Response,
): Promise<void> {
- await this.modelProvider.generateStreamingResponse(content, res);
+ await this.modelProvider.generateStreamingResponse(params, res);
}

async getModelTags(res: Response): Promise<void> {

llm-server/src/main.ts (12 additions, 2 deletions)

@@ -1,6 +1,7 @@
import { Logger } from '@nestjs/common';
import { ChatMessageInput, LLMProvider } from './llm-provider';
import express, { Express, Request, Response } from 'express';
+ import { GenerateMessageParams } from './type/GenerateMessage';

export class App {
private readonly logger = new Logger(App.name);
@@ -27,13 +28,22 @@ export class App {
this.logger.log('Received chat request.');
try {
this.logger.debug(JSON.stringify(req.body));
- const { content } = req.body as ChatMessageInput;
+ const { content, model } = req.body as ChatMessageInput & {
+ model: string;
+ };

+ const params: GenerateMessageParams = {
+ model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
+ message: content,
+ role: 'user',
+ };

this.logger.debug(`Request content: "${content}"`);
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
this.logger.debug('Response headers set for streaming.');
- await this.llmProvider.generateStreamingResponse(content, res);
+ await this.llmProvider.generateStreamingResponse(params, res);
} catch (error) {
this.logger.error('Error in chat endpoint:', error);
res.status(500).json({ error: 'Internal server error' });
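
For reference, a minimal sketch of a client call that exercises the new request shape. The endpoint path, port, and model name are assumptions (the route itself is outside this hunk); omitting model falls back to 'gpt-3.5-turbo' exactly as in the handler above.

    // Hypothetical client call (Node 18+, inside an async context).
    // The '/chat' path and port 3000 are placeholders, not taken from this commit.
    const response = await fetch('http://localhost:3000/chat', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        content: 'Explain server-sent events in one sentence.',
        model: 'gpt-4o-mini', // leave out to use the 'gpt-3.5-turbo' default
      }),
    });
    console.log(response.headers.get('content-type')); // expected: text/event-stream
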

llm-server/src/model/llama-model-provider.ts (3 additions, 2 deletions)

@@ -8,6 +8,7 @@ import {
} from 'node-llama-cpp';
import { ModelProvider } from './model-provider.js';
import { Logger } from '@nestjs/common';
+ import { GenerateMessageParams } from '../type/GenerateMessage';

//TODO: using protocol class
export class LlamaModelProvider extends ModelProvider {
@@ -33,7 +34,7 @@ export class LlamaModelProvider extends ModelProvider {
}

async generateStreamingResponse(
- content: string,
+ { model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
this.logger.log('Generating streaming response with Llama...');
@@ -44,7 +45,7 @@
let chunkCount = 0;
const startTime = Date.now();
try {
- await session.prompt(content, {
+ await session.prompt(message, {
onTextChunk: chunk => {
chunkCount++;
this.logger.debug(`Sending chunk #${chunkCount}: "${chunk}"`);

llm-server/src/model/model-provider.ts (2 additions, 1 deletion)

@@ -1,9 +1,10 @@
import { Response } from 'express';
+ import { GenerateMessageParams } from '../type/GenerateMessage';

export abstract class ModelProvider {
abstract initialize(): Promise<void>;
abstract generateStreamingResponse(
- content: string,
+ params: GenerateMessageParams,
res: Response,
): Promise<void>;

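
To illustrate the updated contract, here is a minimal provider sketch that satisfies the two abstract members visible in this hunk. The EchoModelProvider name and its behaviour are hypothetical, the import paths assume the file sits next to the existing providers, and any further abstract members of ModelProvider (outside this hunk) would also need implementations.

    import { Response } from 'express';
    import { ModelProvider } from './model-provider';
    import { GenerateMessageParams } from '../type/GenerateMessage';

    // Hypothetical provider: streams the incoming message back as a single SSE chunk.
    export class EchoModelProvider extends ModelProvider {
      async initialize(): Promise<void> {
        // Nothing to set up for this sketch.
      }

      async generateStreamingResponse(
        { model, message, role = 'user' }: GenerateMessageParams,
        res: Response,
      ): Promise<void> {
        res.writeHead(200, { 'Content-Type': 'text/event-stream' });
        res.write(`data: [${model}:${role}] ${message}\n\n`);
        res.end();
      }
    }
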

llm-server/src/model/openai-model-provider.ts (17 additions, 3 deletions)

@@ -2,6 +2,10 @@ import { Response } from 'express';
import OpenAI from 'openai';
import { ModelProvider } from './model-provider';
import { Logger } from '@nestjs/common';
+ import { systemPrompts } from '../prompt/systemPrompt';
+ import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
+ import { GenerateMessageParams } from '../type/GenerateMessage';

export class OpenAIModelProvider extends ModelProvider {
private readonly logger = new Logger(OpenAIModelProvider.name);
private openai: OpenAI;
@@ -15,21 +19,31 @@ export class OpenAIModelProvider extends ModelProvider {
}

async generateStreamingResponse(
- content: string,
+ { model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
this.logger.log('Generating streaming response with OpenAI...');
const startTime = Date.now();

+ // Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});

+ // Get the system prompt based on the model
+ const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';

+ // Prepare the messages array, including system prompt if available
+ const messages: ChatCompletionMessageParam[] = systemPrompt
+ ? [{ role: 'system', content: systemPrompt }]
+ : [{ role: role as 'user' | 'system' | 'assistant', content: message }];

try {
const stream = await this.openai.chat.completions.create({
- model: 'gpt-3.5-turbo',
- messages: [{ role: 'user', content: content }],
+ model,
+ messages,
stream: true,
});
let chunkCount = 0;
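
The '../prompt/systemPrompt' module itself is not part of this commit. Inferring only from the systemPrompts['codefox-basic']?.systemPrompt lookup above, a plausible shape for it might be the following; the key layout and prompt text are assumptions, not repository contents.

    // Assumed shape of ../prompt/systemPrompt, inferred from its usage above.
    export interface SystemPromptEntry {
      systemPrompt: string;
    }

    export const systemPrompts: Record<string, SystemPromptEntry> = {
      'codefox-basic': {
        systemPrompt: 'You are a helpful coding assistant.', // placeholder text
      },
    };
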

llm-server/src/type/GenerateMessage.ts (5 additions, 0 deletions)

@@ -0,0 +1,5 @@
+ export interface GenerateMessageParams {
+ model: string; // Model to use, e.g., 'gpt-3.5-turbo'
+ message: string; // User's message or query
+ role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role
+ }
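
For illustration, constructing the params object that now flows from main.ts into the providers; the values are placeholders, and leaving role out relies on the role = 'user' default that the providers destructure with.

    import { GenerateMessageParams } from './GenerateMessage'; // assumes a sibling module

    // Placeholder values, for illustration only.
    const params: GenerateMessageParams = {
      model: 'gpt-3.5-turbo',
      message: 'Summarize the goal of this commit.',
      // role omitted: the providers default it to 'user'
    };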
