Skip to content

Commit

Permalink
feat(backend): Backend config convention and pull model (#54)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
	- Introduced a configuration management system for chat settings.
	- Added functionality to download and manage machine learning models.
	- Implemented a method to download multiple models based on
	  configurations.

- **Bug Fixes**
	- Updated dependencies to enhance functionality and improve data
	  handling.

- **Tests**
	- Added comprehensive tests for loading chat models and verifying their
	  functionality.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Jackson Chen <[email protected]>
Co-authored-by: Jackson Chen <[email protected]>
  • Loading branch information
4 people authored Dec 16, 2024
1 parent 471e9d2 commit cf95728
Show file tree
Hide file tree
Showing 11 changed files with 11,082 additions and 8,146 deletions.
4 changes: 4 additions & 0 deletions backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
},
"dependencies": {
"@apollo/server": "^4.11.0",
"@huggingface/hub": "latest",
"@huggingface/transformers": "latest",
"@nestjs/apollo": "^12.2.0",
"@nestjs/axios": "^3.0.3",
"@nestjs/common": "^10.0.0",
Expand All @@ -36,12 +38,14 @@
"@nestjs/platform-express": "^10.0.0",
"@nestjs/typeorm": "^10.0.2",
"@types/bcrypt": "^5.0.2",
"axios": "^1.7.7",
"@types/fs-extra": "^11.0.4",
"@types/normalize-path": "^3.0.2",
"bcrypt": "^5.1.1",
"class-validator": "^0.14.1",
"fs-extra": "^11.2.0",
"graphql": "^16.9.0",
"lodash": "^4.17.21",
"graphql-subscriptions": "^2.0.0",
"graphql-ws": "^5.16.0",
"markdown-to-txt": "^2.0.1",
Expand Down
444 changes: 438 additions & 6 deletions backend/pnpm-lock.yaml

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions backend/src/config/common-path.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import path from 'path';
import os from 'os';
import * as path from 'path';
import * as os from 'os';
import { existsSync, mkdirSync, promises } from 'fs-extra';
import { createHash } from 'crypto';

Expand Down
55 changes: 55 additions & 0 deletions backend/src/config/config-loader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import * as fs from 'fs';
import * as path from 'path';
import * as _ from 'lodash';
import { getConfigPath } from './common-path';
/**
 * Shape of a single chat-model entry in the configuration file.
 *
 * Only `model` is required; every other field is optional.
 * NOTE(review): the per-field descriptions marked "presumably" are inferred
 * from the field names only — confirm against the code that consumes them.
 */
export interface ChatConfig {
// Identifier of the model (required).
model: string;
// presumably the URL of a remote inference endpoint — TODO confirm
endpoint?: string;
// presumably a credential for `endpoint`; treat it as a secret — TODO confirm
token?: string;
// presumably marks the entry to use when no model is selected — TODO confirm
default?: boolean;
// Optional task name associated with the model.
task?: string;
}

/**
 * Reads, queries, updates and persists the chat configuration file.
 *
 * The file location is resolved once via `getConfigPath('config')` and the
 * file is parsed as JSON when the loader is constructed. Construction throws
 * whatever `fs.readFileSync` / `JSON.parse` throw (missing file, bad JSON);
 * shape checks are performed separately by {@link ConfigLoader.validateConfig}.
 */
export class ConfigLoader {
  private chatsConfig: ChatConfig[];

  private readonly configPath: string;

  constructor() {
    this.configPath = getConfigPath('config');
    // Assigned directly in the constructor so the field is definitely
    // initialised under `strictPropertyInitialization`.
    this.chatsConfig = this.loadConfig();
  }

  /**
   * Reads and parses the configuration file.
   *
   * The parsed content is intentionally NOT logged: entries may carry
   * `token` values that must not end up in console output.
   *
   * @returns the parsed JSON, unvalidated (see `validateConfig`).
   */
  private loadConfig(): ChatConfig[] {
    const file = fs.readFileSync(this.configPath, 'utf-8');
    return JSON.parse(file) as ChatConfig[];
  }

  /**
   * Looks up a value inside the configuration.
   *
   * @param path lodash-style property path; an empty string returns the
   *             whole configuration.
   * @returns the value found at `path`, cast to `T` (no runtime check).
   */
  get<T>(path: string): T {
    if (!path) {
      return this.chatsConfig as unknown as T;
    }
    return _.get(this.chatsConfig, path) as T;
  }

  /**
   * Writes `value` at `path` inside the configuration and immediately
   * persists the whole configuration back to disk.
   *
   * @param path  lodash-style property path.
   * @param value value to store.
   */
  set(path: string, value: unknown): void {
    _.set(this.chatsConfig, path, value);
    this.saveConfig();
  }

  /** Serialises the in-memory configuration to the config file (4-space JSON). */
  private saveConfig(): void {
    fs.writeFileSync(
      this.configPath,
      JSON.stringify(this.chatsConfig, null, 4),
      'utf-8',
    );
  }

  /**
   * Verifies that the loaded configuration is a list of chat entries, each
   * with a non-empty string `model`.
   *
   * @throws Error when the configuration is not an array or an entry has no
   *         usable `model`.
   */
  validateConfig(): void {
    const config: unknown = this.chatsConfig;
    if (!Array.isArray(config)) {
      throw new Error(
        `Invalid configuration: expected a JSON array of chat configs in ${this.configPath}.`,
      );
    }
    config.forEach((entry: unknown, index: number) => {
      const model = (entry as { model?: unknown } | null | undefined)?.model;
      if (typeof model !== 'string' || model.length === 0) {
        throw new Error(
          `Invalid configuration: chat entry at index ${index} is missing a 'model' string.`,
        );
      }
    });
  }
}
3 changes: 3 additions & 0 deletions backend/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';
import 'reflect-metadata';
import { downloadAllModels } from './model/utils';

async function bootstrap() {
const app = await NestFactory.create(AppModule);
Expand All @@ -16,6 +17,8 @@ async function bootstrap() {
'Access-Control-Allow-Credentials',
],
});
await downloadAllModels();
await app.listen(process.env.PORT ?? 3000);
}

bootstrap();
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Test, TestingModule } from '@nestjs/testing';
import { INestApplication } from '@nestjs/common';
import * as request from 'supertest';
import { AppModule } from '../src/app.module';
import { AppModule } from '../../app.module';

describe('AppController (e2e)', () => {
let app: INestApplication;
Expand Down
File renamed without changes.
90 changes: 90 additions & 0 deletions backend/src/model/__tests__/loadAllChatsModels.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import path from 'path';
import * as fs from 'fs';
import { ConfigLoader } from '../../config/config-loader';
import { ModelDownloader } from '../model-downloader';
import { downloadAllModels } from '../utils';
import { getConfigDir, getConfigPath } from 'src/config/common-path';

// Keep a handle on the native implementation so it can be delegated to and
// later restored.
const originalIsArray = Array.isArray;

// Patch Array.isArray for this test file: values whose constructor is named
// Float32Array or BigInt64Array are reported as arrays, everything else is
// answered by the native implementation.
// NOTE(review): presumably a workaround so the model runtime accepts typed
// arrays under jest — confirm the exact reason.
Array.isArray = jest.fn((type: any): type is any[] => {
  const constructorName = type?.constructor?.name;
  const isPatchedTypedArray =
    constructorName === 'Float32Array' || constructorName === 'BigInt64Array';
  return isPatchedTypedArray ? true : originalIsArray(type);
}) as unknown as (arg: any) => arg is any[];

// jest.mock('../../config/config-loader', () => {
// return {
// ConfigLoader: jest.fn().mockImplementation(() => {
// return {
// get: jest.fn().mockReturnValue({
// chat1: {
// model: 'Felladrin/onnx-flan-alpaca-base',
// task: 'text2text-generation',
// },
// }),
// validateConfig: jest.fn(),
// };
// }),
// };
// });

/**
 * Integration test: writes a one-model config, downloads that model for real
 * and runs one inference through the locally stored copy.
 *
 * The config file that existed before the suite is backed up and put back
 * afterwards, so running the tests does not clobber a developer's settings.
 */
describe('loadAllChatsModels with real model loading', () => {
  let configPath: string | undefined;
  // Content of the config file before the suite touched it; null = no file.
  let previousConfig: string | null = null;

  beforeAll(async () => {
    const testConfig = [
      {
        model: 'Felladrin/onnx-flan-alpaca-base',
        task: 'text2text-generation',
      },
    ];
    configPath = getConfigPath('config');
    if (fs.existsSync(configPath)) {
      previousConfig = fs.readFileSync(configPath, 'utf8');
    }
    fs.writeFileSync(configPath, JSON.stringify(testConfig, null, 2), 'utf8');

    await downloadAllModels();
  }, 600000);

  afterAll(() => {
    if (configPath === undefined) {
      // beforeAll failed before resolving the path; nothing to restore.
      return;
    }
    if (previousConfig !== null) {
      fs.writeFileSync(configPath, previousConfig, 'utf8');
    } else if (fs.existsSync(configPath)) {
      fs.unlinkSync(configPath);
    }
  });

  it('should load real models specified in config', async () => {
    const downloader = ModelDownloader.getInstance();

    const chat1Model = await downloader.getLocalModel(
      'text2text-generation',
      'Felladrin/onnx-flan-alpaca-base',
    );
    expect(chat1Model).toBeDefined();
    console.log('Loaded Model:', chat1Model);

    expect(chat1Model).toHaveProperty('model');
    expect(chat1Model).toHaveProperty('tokenizer');

    // No try/catch here on purpose: an inference error or a failed
    // expectation must fail the test instead of being logged and ignored.
    const chat1Output = await chat1Model('Write me a love poem about cheese.', {
      max_new_tokens: 200,
      temperature: 0.9,
      repetition_penalty: 2.0,
      no_repeat_ngram_size: 3,
    });

    console.log('Model Output:', chat1Output);
    expect(chat1Output).toBeDefined();
    expect(chat1Output[0]).toHaveProperty('generated_text');
    console.log(chat1Output[0].generated_text);
  }, 600000);
});

// Once every test in this file has finished, put the native Array.isArray
// back in place of the jest patch.
afterAll(function restoreArrayIsArray() {
  Array.isArray = originalIsArray;
});
31 changes: 31 additions & 0 deletions backend/src/model/model-downloader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { Logger } from '@nestjs/common';
import { PipelineType, pipeline, env } from '@huggingface/transformers';
import { getModelPath, getModelsDir } from 'src/config/common-path';
// Module-level setup of the transformers `env` object, executed when this
// file is imported: allow loading models from local files and set the local
// model path to the directory returned by getModelsDir().
env.allowLocalModels = true;
env.localModelPath = getModelsDir();
/**
 * Lazily created singleton that builds transformers pipelines, either with
 * the models directory as cache (`downloadModel`) or restricted to local
 * files (`getLocalModel`).
 */
export class ModelDownloader {
  private static instance: ModelDownloader;

  /** Logger named after this class; also used by callers of the downloader. */
  readonly logger = new Logger(ModelDownloader.name);

  /** Returns the shared instance, creating it on first use. */
  public static getInstance(): ModelDownloader {
    return (ModelDownloader.instance ??= new ModelDownloader());
  }

  /**
   * Creates a pipeline for `model`, using the models directory as cache dir.
   *
   * @param task  pipeline task name (cast to `PipelineType`).
   * @param model model identifier handed to `pipeline`.
   * @returns the pipeline instance.
   */
  async downloadModel(task: string, model: string): Promise<any> {
    const options = { cache_dir: getModelsDir() };
    return await pipeline(task as PipelineType, model, options);
  }

  /**
   * Creates a pipeline for `model` with `local_files_only` enabled and the
   * revision set to `'local'`.
   *
   * @param task  pipeline task name (cast to `PipelineType`).
   * @param model model identifier handed to `pipeline`.
   * @returns the pipeline instance.
   */
  public async getLocalModel(task: string, model: string): Promise<any> {
    const options = { local_files_only: true, revision: 'local' };
    return await pipeline(task as PipelineType, model, options);
  }
}
22 changes: 22 additions & 0 deletions backend/src/model/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { ChatConfig, ConfigLoader } from 'src/config/config-loader';
import { ModelDownloader } from './model-downloader';

/**
 * Downloads every model listed in the chat configuration.
 *
 * Downloads run in parallel and are best-effort: a failing or incomplete
 * entry is logged and skipped, it does not abort the remaining downloads.
 * Only model names are logged — full entries may contain `token` values.
 *
 * @throws Error when the configuration fails `validateConfig()` or is not a
 *         list of chat entries.
 */
export async function downloadAllModels(): Promise<void> {
  const configLoader = new ConfigLoader();
  configLoader.validateConfig();
  const chats = configLoader.get<ChatConfig[]>('');
  if (!Array.isArray(chats)) {
    throw new Error('Invalid configuration: expected a list of chat configs.');
  }

  const downloader = ModelDownloader.getInstance();
  const modelNames = chats.map((chat) => chat.model).join(', ');
  downloader.logger.log(
    `Loaded config with ${chats.length} chat model(s): ${modelNames}`,
  );

  const outcomes = await Promise.all(
    chats.map(async ({ model, task }: ChatConfig): Promise<boolean> => {
      if (!task) {
        // A pipeline cannot be created without a task, so skip explicitly
        // instead of letting the pipeline call fail with an obscure error.
        downloader.logger.warn(`Skipping model ${model}: no task configured.`);
        return false;
      }
      try {
        downloader.logger.log(`Downloading model ${model} (task: ${task})`);
        await downloader.downloadModel(task, model);
        return true;
      } catch (error: unknown) {
        // Anything can be thrown; only Error instances carry `.message`.
        const reason = error instanceof Error ? error.message : String(error);
        downloader.logger.error(`Failed to load model ${model}: ${reason}`);
        return false;
      }
    }),
  );

  const loadedCount = outcomes.filter(Boolean).length;
  downloader.logger.log(
    `Model download finished: ${loadedCount}/${chats.length} loaded.`,
  );
}
Loading

0 comments on commit cf95728

Please sign in to comment.