feat(backend): Backend config convention and pull model #54

Merged Dec 16, 2024 (25 commits)
Changes from 22 commits

Commits
7635a8a
feat(backend): backend config convention
NarwhalChen Oct 31, 2024
c6fde4a
feat: adding pull model function with test
NarwhalChen Nov 18, 2024
1deca3d
fix: adding default type
NarwhalChen Nov 28, 2024
e69b220
fix: fix the problem of prepare a unready config to other service
NarwhalChen Nov 28, 2024
618c20e
fix: File naming is more semantic
NarwhalChen Nov 30, 2024
f1382eb
Merge remote-tracking branch 'origin/main' into backend-config-conven…
NarwhalChen Dec 10, 2024
96a55d1
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 10, 2024
71b662d
fix: fixing bug in modeldowloader that cannot pull model
NarwhalChen Dec 10, 2024
a4f692d
fix: fixing bug in modeldowloader that cannot pull model
NarwhalChen Dec 10, 2024
458db40
Merge branch 'backend-config-convention' of https://github.com/Sma1lb…
NarwhalChen Dec 10, 2024
bf37046
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 10, 2024
5a6f4bf
to: makes model download when project start and download to local folder
NarwhalChen Dec 12, 2024
ed2e534
to: can load model both locally and remotely
NarwhalChen Dec 13, 2024
3ab6bc8
fix: fixing the layer structure of chatsconfig
NarwhalChen Dec 13, 2024
3fe86b4
fix: fixing the layer structure of chatsconfig and relative bug in Lo…
NarwhalChen Dec 13, 2024
a2c07a6
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 13, 2024
f97b97d
fix: fixing the layer structure of chatsconfig and relative bug in Lo…
NarwhalChen Dec 15, 2024
eb55281
Merge branch 'backend-config-convention' of https://github.com/Sma1lb…
NarwhalChen Dec 15, 2024
9c34f08
refactor: rename ConfigLoader and ModelLoader files, update imports, …
Sma1lboy Dec 15, 2024
ed07140
fix: updating layer structure of chatconfig and updating relative test
NarwhalChen Dec 15, 2024
8a311f0
Delete .editorconfig
Sma1lboy Dec 16, 2024
28d5f78
Delete pnpm-lock.yaml
Sma1lboy Dec 16, 2024
c1c8a98
Delete pnpm-lock.yaml
Sma1lboy Dec 16, 2024
3ea4222
Merge branch 'main' into backend-config-convention
Sma1lboy Dec 16, 2024
4e59695
Fix merge conflicts in pnpm-lock.yaml and update axios version to 1.7.8
Sma1lboy Dec 16, 2024
4 changes: 4 additions & 0 deletions backend/package.json
@@ -26,6 +26,8 @@
},
"dependencies": {
"@apollo/server": "^4.11.0",
"@huggingface/hub": "latest",
"@huggingface/transformers": "latest",
Comment on lines +29 to +30

⚠️ Potential issue

Pin Hugging Face dependencies to specific versions

Using "latest" for production dependencies is risky as it can lead to unexpected breaking changes. Consider pinning to specific versions.

- "@huggingface/hub": "latest",
- "@huggingface/transformers": "latest",
+ "@huggingface/hub": "^0.14.1",
+ "@huggingface/transformers": "^2.15.0",

Committable suggestion skipped: line range outside the PR's diff.
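
As a rough illustration of the pinning advice, the check below scans a dependency map for ranges that float to the newest release. The helper name `findFloatingDependencies` and the set of ranges it treats as floating are assumptions made for this sketch, not part of the PR or of any package manager.

```typescript
// Hypothetical helper: lists dependencies whose range floats to the newest release.
type DependencyMap = Record<string, string>;

// Ranges treated as "floating" in this sketch (an assumption, not an exhaustive list).
const FLOATING_RANGES = new Set(['latest', '*', '']);

function findFloatingDependencies(deps: DependencyMap): string[] {
  return Object.entries(deps)
    .filter(([, range]) => FLOATING_RANGES.has(range.trim()))
    .map(([name]) => name)
    .sort();
}

// Example input mirroring three entries from this diff.
const exampleDependencies: DependencyMap = {
  '@apollo/server': '^4.11.0',
  '@huggingface/hub': 'latest',
  '@huggingface/transformers': 'latest',
};

console.log(findFloatingDependencies(exampleDependencies));
```

A check like this could run in CI so that a floating range is caught before it reaches the lockfile.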

"@nestjs/apollo": "^12.2.0",
"@nestjs/axios": "^3.0.3",
"@nestjs/common": "^10.0.0",
@@ -36,11 +38,13 @@
"@nestjs/platform-express": "^10.0.0",
"@nestjs/typeorm": "^10.0.2",
"@types/bcrypt": "^5.0.2",
"axios": "^1.7.7",
"@types/fs-extra": "^11.0.4",
"bcrypt": "^5.1.1",
"class-validator": "^0.14.1",
"fs-extra": "^11.2.0",
"graphql": "^16.9.0",
"lodash": "^4.17.21",
"graphql-subscriptions": "^2.0.0",
"graphql-ws": "^5.16.0",
"markdown-to-txt": "^2.0.1",
444 changes: 438 additions & 6 deletions backend/pnpm-lock.yaml

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions backend/src/config/common-path.ts
@@ -1,5 +1,5 @@
- import path from 'path';
- import os from 'os';
+ import * as path from 'path';
+ import * as os from 'os';
import { existsSync, mkdirSync, promises } from 'fs-extra';
import { createHash } from 'crypto';

55 changes: 55 additions & 0 deletions backend/src/config/config-loader.ts
@@ -0,0 +1,55 @@
import * as fs from 'fs';
import * as path from 'path';
import * as _ from 'lodash';
import { getConfigPath } from './common-path';
export interface ChatConfig {
model: string;
endpoint?: string;
token?: string;
default?: boolean;
task?: string;
}

export class ConfigLoader {
private chatsConfig: ChatConfig[];

private readonly configPath: string;

constructor() {
this.configPath = getConfigPath('config');
this.loadConfig();
}

private loadConfig() {
const file = fs.readFileSync(this.configPath, 'utf-8');

this.chatsConfig = JSON.parse(file);
console.log('Raw file content:', this.chatsConfig);
}
Comment on lines +23 to +28

⚠️ Potential issue

Add error handling and consider async file operations

The loadConfig method reads the file synchronously and has no error handling, so a missing file or malformed JSON throws from the constructor without any context about which step failed.

- private loadConfig() {
+ private loadConfig(): void {
+   try {
      const file = fs.readFileSync(this.configPath, 'utf-8');
      this.chatsConfig = JSON.parse(file);
      console.log('Raw file content:', this.chatsConfig);
+   } catch (error) {
+     console.error('Failed to load config:', error);
+     this.chatsConfig = [];
+     throw new Error(`Failed to load configuration: ${error.message}`);
+   }
  }
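
The same idea can be sketched outside the class. In the version below, `loadChatConfig` and its `readRaw` parameter are hypothetical: `readRaw` stands in for the `fs.readFileSync(configPath, 'utf-8')` call so the example runs without touching the file system. It also rejects non-array JSON, which is an extra assumption beyond the suggestion above.

```typescript
interface ChatConfig {
  model: string;
  endpoint?: string;
  token?: string;
  default?: boolean;
  task?: string;
}

// Turns any thrown value into a readable message.
function describeError(error: unknown): string {
  return error instanceof Error ? error.message : String(error);
}

// Hypothetical loader: `readRaw` is injected instead of reading a real file.
function loadChatConfig(readRaw: () => string): ChatConfig[] {
  let parsed: unknown;
  try {
    parsed = JSON.parse(readRaw());
  } catch (error) {
    // Covers both read failures and JSON syntax errors.
    throw new Error(`Failed to load configuration: ${describeError(error)}`);
  }
  if (!Array.isArray(parsed)) {
    throw new Error('Failed to load configuration: expected a JSON array.');
  }
  return parsed as ChatConfig[];
}
```

Unlike the suggestion above, this sketch does not assign an empty list before rethrowing, since the caller never sees that value once the error propagates.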


get<T>(path: string) {
if (!path) {
return this.chatsConfig as unknown as T;
}
return _.get(this.chatsConfig, path) as T;
}
Comment on lines +30 to +35

🛠️ Refactor suggestion

Improve type safety in get method

The generic get method could be made more type-safe by constraining the return type based on the path.

- get<T>(path: string) {
+ get<K extends keyof ChatConfig>(path: K): ChatConfig[K];
+ get<T>(path: string): T;
+ get(path: string): unknown {
    if (!path) {
-     return this.chatsConfig as unknown as T;
+     return this.chatsConfig;
    }
-   return _.get(this.chatsConfig, path) as T;
+   return _.get(this.chatsConfig, path);
  }

Committable suggestion skipped: line range outside the PR's diff.
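
Because the configuration in this PR is a list of entries, one possible way to get key-based return types is to address an entry by index first. The sketch below assumes a hypothetical `ChatConfigStore` class with overloads; it is not the `ConfigLoader` from the diff and does not use lodash.

```typescript
interface ChatConfig {
  model: string;
  endpoint?: string;
  token?: string;
  default?: boolean;
  task?: string;
}

// Hypothetical store used only to illustrate overload-based typing.
class ChatConfigStore {
  private readonly chats: ChatConfig[];

  constructor(chats: ChatConfig[]) {
    this.chats = chats;
  }

  // No arguments: the whole list.
  get(): ChatConfig[];
  // Index only: one entry, or undefined when out of range.
  get(index: number): ChatConfig | undefined;
  // Index and key: the return type is derived from the key.
  get<K extends keyof ChatConfig>(index: number, key: K): ChatConfig[K] | undefined;
  get(index?: number, key?: keyof ChatConfig): unknown {
    if (index === undefined) return this.chats;
    const entry = this.chats[index];
    if (key === undefined) return entry;
    return entry === undefined ? undefined : entry[key];
  }
}

const store = new ChatConfigStore([{ model: 'example-model', task: 'text2text-generation' }]);
const firstModel: string | undefined = store.get(0, 'model');
console.log(firstModel);
```

With this shape the compiler rejects a misspelled key at the call site, which a free-form string path cannot do.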


set(path: string, value: any) {
_.set(this.chatsConfig, path, value);
this.saveConfig();
}

private saveConfig() {
fs.writeFileSync(
this.configPath,
JSON.stringify(this.chatsConfig, null, 4),
'utf-8',
);
}

validateConfig() {
if (!this.chatsConfig) {
throw new Error("Invalid configuration: 'chats' section is missing.");
}
}
Comment on lines +50 to +54

🛠️ Refactor suggestion

Enhance configuration validation

The current validation only checks that chatsConfig is truthy. Consider also checking that it is an array and that every entry has the `model` and `task` fields.

 validateConfig() {
   if (!this.chatsConfig) {
     throw new Error("Invalid configuration: 'chats' section is missing.");
   }
+  if (!Array.isArray(this.chatsConfig)) {
+    throw new Error("Invalid configuration: 'chats' must be an array.");
+  }
+  for (const config of this.chatsConfig) {
+    if (!config.model) {
+      throw new Error("Invalid configuration: 'model' is required.");
+    }
+    if (!config.task) {
+      throw new Error("Invalid configuration: 'task' is required.");
+    }
+  }
 }
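
A variation on the same checks is to collect every problem instead of stopping at the first one, which can make a broken config easier to fix in one pass. The function name `validateChatConfigs` and the message wording are assumptions for this sketch.

```typescript
// Hypothetical validator: returns a list of problems, empty when the config is acceptable.
function validateChatConfigs(value: unknown): string[] {
  if (!Array.isArray(value)) {
    return ['configuration must be an array of chat entries'];
  }
  const problems: string[] = [];
  value.forEach((entry: unknown, index: number) => {
    if (typeof entry !== 'object' || entry === null) {
      problems.push(`entry ${index}: must be an object`);
      return;
    }
    const record = entry as Record<string, unknown>;
    // Both fields are required here, matching the suggestion above.
    for (const field of ['model', 'task']) {
      const fieldValue = record[field];
      if (typeof fieldValue !== 'string' || fieldValue.length === 0) {
        problems.push(`entry ${index}: '${field}' must be a non-empty string`);
      }
    }
  });
  return problems;
}

console.log(validateChatConfigs([{ model: 'example-model' }]));
```

A caller that prefers the throwing style can join the returned messages into a single error.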

}
3 changes: 3 additions & 0 deletions backend/src/main.ts
@@ -1,6 +1,7 @@
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';
import 'reflect-metadata';
import { downloadAllModels } from './model/utils';

async function bootstrap() {
const app = await NestFactory.create(AppModule);
@@ -16,6 +17,8 @@ async function bootstrap() {
'Access-Control-Allow-Credentials',
],
});
await downloadAllModels();
await app.listen(process.env.PORT ?? 3000);
}

bootstrap();
@@ -1,7 +1,7 @@
import { Test, TestingModule } from '@nestjs/testing';
import { INestApplication } from '@nestjs/common';
import * as request from 'supertest';
- import { AppModule } from '../src/app.module';
+ import { AppModule } from '../../app.module';

describe('AppController (e2e)', () => {
let app: INestApplication;
File renamed without changes.
90 changes: 90 additions & 0 deletions backend/src/model/__tests__/loadAllChatsModels.spec.ts
@@ -0,0 +1,90 @@
import path from 'path';
import * as fs from 'fs';
import { ConfigLoader } from '../../config/config-loader';
import { ModelDownloader } from '../model-downloader';
import { downloadAllModels } from '../utils';
import { getConfigDir, getConfigPath } from 'src/config/common-path';

const originalIsArray = Array.isArray;

Array.isArray = jest.fn((type: any): type is any[] => {
if (
type &&
type.constructor &&
(type.constructor.name === 'Float32Array' ||
type.constructor.name === 'BigInt64Array')
) {
return true;
}
return originalIsArray(type);
}) as unknown as (arg: any) => arg is any[];
Comment on lines +8 to +20

⚠️ Potential issue

Avoid modifying global Array.isArray

Modifying the global Array.isArray is risky and could affect other tests. Consider using a more isolated approach.

- const originalIsArray = Array.isArray;
- Array.isArray = jest.fn((type: any): type is any[] => {
+ const isCustomArray = (type: any): type is any[] => {
   if (
     type &&
     type.constructor &&
     (type.constructor.name === 'Float32Array' ||
       type.constructor.name === 'BigInt64Array')
   ) {
     return true;
   }
-  return originalIsArray(type);
+  return Array.isArray(type);
- }) as unknown as (arg: any) => arg is any[];
+ };

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Biome (1.9.4)

[error] 12-13: Change to an optional chain.

Unsafe fix: Change to an optional chain.

(lint/complexity/useOptionalChain)
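
For checks made by the test code itself, a local helper can recognise typed arrays without replacing the global. The sketch below uses `ArrayBuffer.isView`, so it accepts every typed array rather than only `Float32Array` and `BigInt64Array`; that broader behaviour and the name `isArrayOrTypedArray` are assumptions of this example.

```typescript
// Hypothetical helper: true for plain arrays and for typed arrays, false otherwise.
function isArrayOrTypedArray(value: unknown): boolean {
  if (Array.isArray(value)) {
    return true;
  }
  // ArrayBuffer.isView is true for typed arrays and DataView; DataView is excluded here.
  return ArrayBuffer.isView(value) && !(value instanceof DataView);
}

console.log(isArrayOrTypedArray(new Float32Array(4)));
console.log(isArrayOrTypedArray('not an array'));
```

Note that a local helper does not change what a dependency sees when it calls `Array.isArray` internally, so it only replaces the override where the test owns the call.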


// jest.mock('../../config/config-loader', () => {
// return {
// ConfigLoader: jest.fn().mockImplementation(() => {
// return {
// get: jest.fn().mockReturnValue({
// chat1: {
// model: 'Felladrin/onnx-flan-alpaca-base',
// task: 'text2text-generation',
// },
// }),
// validateConfig: jest.fn(),
// };
// }),
// };
// });

describe('loadAllChatsModels with real model loading', () => {
let configLoader: ConfigLoader;
beforeAll(async () => {
const testConfig = [
{
model: 'Felladrin/onnx-flan-alpaca-base',
task: 'text2text-generation',
}
];
const configPath = getConfigPath('config');
fs.writeFileSync(configPath, JSON.stringify(testConfig, null, 2), 'utf8');

configLoader = new ConfigLoader();
await downloadAllModels();
}, 600000);

it('should load real models specified in config', async () => {
const downloader = ModelDownloader.getInstance();

const chat1Model = await downloader.getLocalModel(
'text2text-generation',
'Felladrin/onnx-flan-alpaca-base',
);
expect(chat1Model).toBeDefined();
console.log('Loaded Model:', chat1Model);

expect(chat1Model).toHaveProperty('model');
expect(chat1Model).toHaveProperty('tokenizer');

try {
const chat1Output = await chat1Model(
'Write me a love poem about cheese.',
{
max_new_tokens: 200,
temperature: 0.9,
repetition_penalty: 2.0,
no_repeat_ngram_size: 3,
},
);

console.log('Model Output:', chat1Output);
expect(chat1Output).toBeDefined();
expect(chat1Output[0]).toHaveProperty('generated_text');
console.log(chat1Output[0].generated_text);
} catch (error) {
console.error('Error during model inference:', error);
}
}, 600000);
});

afterAll(() => {
Array.isArray = originalIsArray;
});
31 changes: 31 additions & 0 deletions backend/src/model/model-downloader.ts
@@ -0,0 +1,31 @@
import { Logger } from '@nestjs/common';
import { PipelineType, pipeline, env } from '@huggingface/transformers';
import { getModelPath, getModelsDir } from 'src/config/common-path';
env.allowLocalModels = true;
env.localModelPath = getModelsDir();
export class ModelDownloader {
readonly logger = new Logger(ModelDownloader.name);
private static instance: ModelDownloader;
public static getInstance(): ModelDownloader {
if (!ModelDownloader.instance) {
ModelDownloader.instance = new ModelDownloader();
}
return ModelDownloader.instance;
}
Comment on lines +9 to +14

🛠️ Refactor suggestion

Improve singleton implementation for thread safety

The current singleton implementation might have race conditions in a concurrent environment.

Consider using a more robust singleton pattern:

- private static instance: ModelDownloader;
- public static getInstance(): ModelDownloader {
-   if (!ModelDownloader.instance) {
-     ModelDownloader.instance = new ModelDownloader();
-   }
-   return ModelDownloader.instance;
- }
+ private static instance: ModelDownloader | null = null;
+ private static instanceLock = false;
+ public static getInstance(): ModelDownloader {
+   if (!ModelDownloader.instance) {
+     if (ModelDownloader.instanceLock) {
+       throw new Error('Instance creation in progress');
+     }
+     ModelDownloader.instanceLock = true;
+     try {
+       ModelDownloader.instance = new ModelDownloader();
+     } finally {
+       ModelDownloader.instanceLock = false;
+     }
+   }
+   return ModelDownloader.instance;
+ }

Committable suggestion skipped: line range outside the PR's diff.
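
For context, JavaScript runs a synchronous function to completion on a single thread, so a synchronous `getInstance` cannot interleave with itself. Overlap becomes possible once initialization is awaited. A minimal sketch of one common way to handle that case is to cache the promise rather than the instance; `LazyResource` and its `create` step are hypothetical and unrelated to `ModelDownloader`.

```typescript
// Hypothetical class showing a promise-cached singleton for asynchronous setup.
class LazyResource {
  private static instancePromise: Promise<LazyResource> | null = null;
  static constructions = 0;

  private constructor() {
    LazyResource.constructions += 1;
  }

  // Stand-in for any awaited initialization work.
  private static async create(): Promise<LazyResource> {
    await Promise.resolve();
    return new LazyResource();
  }

  static getInstance(): Promise<LazyResource> {
    // The promise is stored synchronously, so callers that arrive
    // while setup is still pending share the same promise.
    if (LazyResource.instancePromise === null) {
      LazyResource.instancePromise = LazyResource.create();
    }
    return LazyResource.instancePromise;
  }
}

// Two overlapping calls resolve to the same instance.
async function demoSharedInstance(): Promise<boolean> {
  const [first, second] = await Promise.all([
    LazyResource.getInstance(),
    LazyResource.getInstance(),
  ]);
  return first === second && LazyResource.constructions === 1;
}
```

No lock flag is needed in this shape, because the assignment to `instancePromise` happens before control returns to the event loop.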


async downloadModel(task: string, model: string): Promise<any> {
const pipelineInstance = await pipeline(task as PipelineType, model, {
cache_dir: getModelsDir(),
});
return pipelineInstance;
}

public async getLocalModel(task: string, model: string): Promise<any> {
const pipelineInstance = await pipeline(task as PipelineType, model, {
local_files_only: true,
revision: 'local',
});

return pipelineInstance;
}
}
22 changes: 22 additions & 0 deletions backend/src/model/utils.ts
@@ -0,0 +1,22 @@
import { ChatConfig, ConfigLoader } from 'src/config/config-loader';
import { ModelDownloader } from './model-downloader';

export async function downloadAllModels(): Promise<void> {
const configLoader = new ConfigLoader();
configLoader.validateConfig();
const chats = configLoader.get<ChatConfig[]>('');
const downloader = ModelDownloader.getInstance();
console.log('Loaded config:', chats);
const loadPromises = chats.map(async (chatConfig: ChatConfig) => {
const { model, task } = chatConfig;
try {
downloader.logger.log(model, task);
const pipelineInstance = await downloader.downloadModel(task, model);
} catch (error) {
downloader.logger.error(`Failed to load model ${model}:`, error.message);
}
});
Comment on lines +10 to +18

⚠️ Potential issue

Add error handling and type safety improvements

Several issues need attention in the model loading logic:

  1. The unused pipelineInstance variable
  2. No type safety for the task parameter
  3. Errors are caught but the function continues silently

Consider this improved implementation:

- const loadPromises = chats.map(async (chatConfig: ChatConfig) => {
+ type SupportedTask = 'text-generation' | 'sentiment-analysis'; // add all supported tasks
+ const loadPromises = chats.map(async (chatConfig: ChatConfig) => {
    const { model, task } = chatConfig;
+   if (!isValidTask(task)) {
+     throw new Error(`Unsupported task: ${task}`);
+   }
    try {
      downloader.logger.log(model, task);
-     const pipelineInstance = await downloader.downloadModel(task, model);
+     await downloader.downloadModel(task as SupportedTask, model);
    } catch (error) {
      downloader.logger.error(`Failed to load model ${model}:`, error.message);
+     throw error; // re-throw to indicate failure
    }
  });
- await Promise.all(loadPromises);
+ const results = await Promise.allSettled(loadPromises);
+ const failures = results.filter(r => r.status === 'rejected');
+ if (failures.length > 0) {
+   throw new Error(`Failed to load ${failures.length} models`);
+ }

Committable suggestion skipped: line range outside the PR's diff.
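
The `Promise.allSettled` part of the suggestion can be sketched in isolation. Here `downloadAll`, `ModelSpec`, and the injected `Downloader` function are hypothetical stand-ins, so the example runs without the Hugging Face packages; the real `downloadModel` call may behave differently.

```typescript
// Hypothetical signature standing in for a model download call.
type Downloader = (task: string, model: string) => Promise<unknown>;

interface ModelSpec {
  model: string;
  task: string;
}

// Runs every download to completion, then reports all failures in one error.
async function downloadAll(specs: ModelSpec[], download: Downloader): Promise<void> {
  const results = await Promise.allSettled(
    specs.map(({ task, model }) => download(task, model)),
  );
  const failed: string[] = [];
  results.forEach((result, index) => {
    if (result.status === 'rejected') {
      const reason =
        result.reason instanceof Error ? result.reason.message : String(result.reason);
      failed.push(`${specs[index].model}: ${reason}`);
    }
  });
  if (failed.length > 0) {
    throw new Error(`Failed to load ${failed.length} model(s): ${failed.join('; ')}`);
  }
}
```

Compared with `Promise.all`, this waits for every download before failing, so one bad entry does not hide the outcome of the others.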

await Promise.all(loadPromises);

downloader.logger.log('All models loaded.');
}