Skip to content

Commit

Permalink
feat: adding model downloader to llmserver and load apikey dynamically
Browse files · Browse the repository at this point in the history
  • Loading branch information
NarwhalChen committed Jan 12, 2025
1 parent a78d34f commit 4dcb9ff
Show file tree · Hide file tree
Showing 14 changed files with 15,543 additions and 11,977 deletions.
145 changes: 43 additions & 102 deletions backend/src/config/config-loader.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import * as fs from 'fs';
import * as _ from 'lodash';
import { getConfigPath } from './common-path';
import { ConfigType } from 'src/downloader/universal-utils';
import { Logger } from '@nestjs/common';
import * as path from 'path';

export interface ModelConfig {
model: string;
endpoint?: string;
token?: string;
default?: boolean;
task?: string;
}

export interface EmbeddingConfig {
model: string;
endpoint?: string;
Expand All @@ -21,72 +12,50 @@ export interface EmbeddingConfig {
}

export interface AppConfig {
models?: ModelConfig[];
embeddings?: EmbeddingConfig[];
}

export const exampleConfigContent = `{
// Chat models configuration
// You can configure multiple chat models
"models": [
// Example of OpenAI GPT configuration
// Embedding model configuration
// You can configure multiple embedding models
"embeddings": [
{
"model": "gpt-3.5-turbo",
"model": "text-embedding-ada-002",
"endpoint": "https://api.openai.com/v1",
"token": "your-openai-token", // Replace with your OpenAI token
"default": true // Set as default chat model
},
// Example of local model configuration
{
"model": "llama2",
"endpoint": "http://localhost:11434/v1"
"token": "your-openai-token", // Replace with your OpenAI token
"default": true // Set as default embedding
}
],
// Embedding model configuration (optional)
"embeddings": [{
"model": "text-embedding-ada-002",
"endpoint": "https://api.openai.com/v1",
"token": "your-openai-token", // Replace with your OpenAI token
"default": true // Set as default embedding
}]
]
}`;

export class ConfigLoader {
readonly logger = new Logger(ConfigLoader.name);
private type: string;
private static instances: Map<ConfigType, ConfigLoader> = new Map();
private static instance: ConfigLoader;
private static config: AppConfig;
private readonly configPath: string;

private constructor(type: ConfigType) {
this.type = type;
private constructor() {
this.configPath = getConfigPath();
this.initConfigFile();
this.loadConfig();
}

public static getInstance(type: ConfigType): ConfigLoader {
if (!ConfigLoader.instances.has(type)) {
ConfigLoader.instances.set(type, new ConfigLoader(type));
public static getInstance(): ConfigLoader {
if (!ConfigLoader.instance) {
ConfigLoader.instance = new ConfigLoader();
}
return ConfigLoader.instances.get(type)!;
return ConfigLoader.instance;
}

public initConfigFile(): void {
Logger.log('Creating example config file', 'ConfigLoader');
this.logger.log('Initializing configuration file', 'ConfigLoader');

const config = getConfigPath();
if (fs.existsSync(config)) {
if (fs.existsSync(this.configPath)) {
return;
}

if (!fs.existsSync(config)) {
//make file
fs.writeFileSync(config, exampleConfigContent, 'utf-8');
}
Logger.log('Creating example config file', 'ConfigLoader');
fs.writeFileSync(this.configPath, exampleConfigContent, 'utf-8');
this.logger.log('Example configuration file created', 'ConfigLoader');
}

public reload(): void {
Expand All @@ -95,10 +64,7 @@ export class ConfigLoader {

private loadConfig() {
try {
Logger.log(
`Loading configuration from ${this.configPath}`,
'ConfigLoader',
);
this.logger.log(`Loading configuration from ${this.configPath}`, 'ConfigLoader');
const file = fs.readFileSync(this.configPath, 'utf-8');
const jsonContent = file.replace(
/\\"|"(?:\\"|[^"])*"|(\/\/.*|\/\*[\s\S]*?\*\/)/g,
Expand Down Expand Up @@ -145,49 +111,48 @@ export class ConfigLoader {
);
}

addConfig(config: ModelConfig | EmbeddingConfig) {
if (!ConfigLoader.config[this.type]) {
ConfigLoader.config[this.type] = [];
addConfig(config: EmbeddingConfig) {
if (!ConfigLoader.config.embeddings) {
ConfigLoader.config.embeddings = [];
}
this.logger.log(ConfigLoader.config);
const index = ConfigLoader.config[this.type].findIndex(
(chat) => chat.model === config.model,

const index = ConfigLoader.config.embeddings.findIndex(
(emb) => emb.model === config.model,
);
if (index !== -1) {
ConfigLoader.config[this.type].splice(index, 1);
ConfigLoader.config.embeddings.splice(index, 1);
}

if (config.default) {
ConfigLoader.config.models.forEach((chat) => {
chat.default = false;
ConfigLoader.config.embeddings.forEach((emb) => {
emb.default = false;
});
}

ConfigLoader.config[this.type].push(config);
ConfigLoader.config.embeddings.push(config);
this.saveConfig();
}

removeConfig(modelName: string): boolean {
if (!ConfigLoader.config[this.type]) {
if (!ConfigLoader.config.embeddings) {
return false;
}

const initialLength = ConfigLoader.config[this.type].length;
ConfigLoader.config.models = ConfigLoader.config[this.type].filter(
(chat) => chat.model !== modelName,
const initialLength = ConfigLoader.config.embeddings.length;
ConfigLoader.config.embeddings = ConfigLoader.config.embeddings.filter(
(emb) => emb.model !== modelName,
);

if (ConfigLoader.config[this.type].length !== initialLength) {
if (ConfigLoader.config.embeddings.length !== initialLength) {
this.saveConfig();
return true;
}

return false;
}

getAllConfigs(): EmbeddingConfig[] | ModelConfig[] | null {
const res = ConfigLoader.config[this.type];
return Array.isArray(res) ? res : null;
getAllConfigs(): EmbeddingConfig[] | null {
return ConfigLoader.config.embeddings || null;
}

validateConfig() {
Expand All @@ -199,49 +164,25 @@ export class ConfigLoader {
throw new Error('Invalid configuration: Must be an object');
}

if (ConfigLoader.config.models) {
if (!Array.isArray(ConfigLoader.config.models)) {
throw new Error("Invalid configuration: 'chats' must be an array");
}

ConfigLoader.config.models.forEach((chat, index) => {
if (!chat.model) {
throw new Error(
`Invalid chat configuration at index ${index}: 'model' is required`,
);
}
});

const defaultChats = ConfigLoader.config.models.filter(
(chat) => chat.default,
);
if (defaultChats.length > 1) {
throw new Error(
'Invalid configuration: Multiple default chat configurations found',
);
}
}

if (ConfigLoader.config[ConfigType.EMBEDDINGS]) {
this.logger.log(ConfigLoader.config[ConfigType.EMBEDDINGS]);
if (!Array.isArray(ConfigLoader.config[ConfigType.EMBEDDINGS])) {
if (ConfigLoader.config.embeddings) {
if (!Array.isArray(ConfigLoader.config.embeddings)) {
throw new Error("Invalid configuration: 'embeddings' must be an array");
}

ConfigLoader.config.models.forEach((emb, index) => {
ConfigLoader.config.embeddings.forEach((emb, index) => {
if (!emb.model) {
throw new Error(
`Invalid chat configuration at index ${index}: 'model' is required`,
`Invalid embedding configuration at index ${index}: 'model' is required`,
);
}
});

const defaultChats = ConfigLoader.config[ConfigType.EMBEDDINGS].filter(
(chat) => chat.default,
const defaultEmbeddings = ConfigLoader.config.embeddings.filter(
(emb) => emb.default,
);
if (defaultChats.length > 1) {
if (defaultEmbeddings.length > 1) {
throw new Error(
'Invalid configuration: Multiple default emb configurations found',
'Invalid configuration: Multiple default embedding configurations found',
);
}
}
Expand Down
85 changes: 24 additions & 61 deletions backend/src/downloader/__tests__/loadAllChatsModels.spec.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import path from 'path';
import * as fs from 'fs';
import {
ConfigLoader,
ModelConfig,
EmbeddingConfig,
} from '../../config/config-loader';
import { UniversalDownloader } from '../model-downloader';
import { ConfigType, downloadAll, TaskType } from '../universal-utils';

import { ConfigLoader, EmbeddingConfig } from '../../config/config-loader';
import { EmbeddingDownloader } from '../embedding-downloader';
import { downloadAllEmbeddings } from '../universal-utils';

const originalIsArray = Array.isArray;

Expand All @@ -22,76 +17,44 @@ Array.isArray = jest.fn((type: any): type is any[] => {
return originalIsArray(type);
}) as unknown as (arg: any) => arg is any[];

// jest.mock('../../config/config-loader', () => {
// return {
// ConfigLoader: jest.fn().mockImplementation(() => {
// return {
// get: jest.fn().mockReturnValue({
// chat1: {
// model: 'Felladrin/onnx-flan-alpaca-base',
// task: 'text2text-generation',
// },
// }),
// validateConfig: jest.fn(),
// };
// }),
// };
// });

describe('loadAllChatsModels with real model loading', () => {
let modelConfigLoader: ConfigLoader;
describe('loadAllEmbeddingModels with real model loading', () => {
let embConfigLoader: ConfigLoader;
beforeAll(async () => {
modelConfigLoader = ConfigLoader.getInstance(ConfigType.CHATS);
embConfigLoader = ConfigLoader.getInstance(ConfigType.EMBEDDINGS);
const modelConfig: ModelConfig = {
model: 'Xenova/flan-t5-small',
endpoint: 'http://localhost:11434/v1',
token: 'your-token-here',
task: 'text2text-generation',
};
modelConfigLoader.addConfig(modelConfig);

beforeAll(async () => {
embConfigLoader = ConfigLoader.getInstance();
const embConfig: EmbeddingConfig = {
model: 'fast-bge-base-en-v1.5',
endpoint: 'http://localhost:11434/v1',
token: 'your-token-here',
};
embConfigLoader.addConfig(embConfig);

console.log('preload starts');
await downloadAll();
await downloadAllEmbeddings();
console.log('preload successfully');
}, 60000000);
}, 6000000);

it('should load real models specified in config', async () => {
const downloader = UniversalDownloader.getInstance();
const chat1Model = await downloader.getLocalModel(
TaskType.CHAT,
'Xenova/flan-t5-small',
it('should load real embedding models specified in config', async () => {
const downloader = EmbeddingDownloader.getInstance();
const embeddingModel = await downloader.getPipeline(
'fast-bge-base-en-v1.5',
);
expect(chat1Model).toBeDefined();
console.log('Loaded Model:', chat1Model);
expect(embeddingModel).toBeDefined();
console.log('Loaded Embedding Model:', embeddingModel);

expect(chat1Model).toHaveProperty('model');
expect(chat1Model).toHaveProperty('tokenizer');
expect(embeddingModel).toHaveProperty('model');

try {
const chat1Output = await chat1Model(
'Write me a love poem about cheese.',
{
max_new_tokens: 200,
temperature: 0.9,
repetition_penalty: 2.0,
no_repeat_ngram_size: 3,
},
const embeddingOutput = await embeddingModel.embed(
['Test input sentence for embedding.']
);
for await (const batch of embeddingOutput) {
console.log(batch);
}

console.log('Model Output:', chat1Output);
expect(chat1Output).toBeDefined();
expect(chat1Output[0]).toHaveProperty('generated_text');
console.log(chat1Output[0].generated_text);
expect(embeddingOutput).toBeDefined();
} catch (error) {
console.error('Error during model inference:', error);
console.error('Error during embedding model inference:', error);
}
}, 6000000);
});
Expand Down
Loading

0 comments on commit 4dcb9ff

Please sign in to comment.