-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(backend): Backend config convention and pull model #54
Changes from 22 commits
7635a8a
c6fde4a
1deca3d
e69b220
618c20e
f1382eb
96a55d1
71b662d
a4f692d
458db40
bf37046
5a6f4bf
ed2e534
3ab6bc8
3fe86b4
a2c07a6
f97b97d
eb55281
9c34f08
ed07140
8a311f0
28d5f78
c1c8a98
3ea4222
4e59695
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,55 @@ | ||||||||||||||||||||||||||||||||||||||||||||
import * as fs from 'fs'; | ||||||||||||||||||||||||||||||||||||||||||||
import * as path from 'path'; | ||||||||||||||||||||||||||||||||||||||||||||
import * as _ from 'lodash'; | ||||||||||||||||||||||||||||||||||||||||||||
import { getConfigPath } from './common-path'; | ||||||||||||||||||||||||||||||||||||||||||||
/**
 * One chat-model entry in the backend configuration file.
 *
 * `ConfigLoader` parses the config file into an array of these entries.
 */
export interface ChatConfig {
  /** Model identifier (the test config uses a Hugging Face repo id such as 'Felladrin/onnx-flan-alpaca-base'). */
  model: string;
  /** NOTE(review): presumably a remote endpoint URL for this model — not read anywhere in view; confirm. */
  endpoint?: string;
  /** NOTE(review): presumably a credential for `endpoint` — treat as a secret; confirm usage. */
  token?: string;
  /** NOTE(review): presumably marks the entry to use when no model is named — confirm. */
  default?: boolean;
  /** Pipeline task name for the model (e.g. 'text2text-generation' in the test config). */
  task?: string;
}
|
||||||||||||||||||||||||||||||||||||||||||||
export class ConfigLoader { | ||||||||||||||||||||||||||||||||||||||||||||
private chatsConfig: ChatConfig[]; | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
private readonly configPath: string; | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Resolves the config file location and loads the file immediately.
 *
 * Loading is synchronous, so construction throws if the file cannot be
 * read or does not contain valid JSON.
 */
constructor() {
  // The location comes from the shared path helper, looked up under 'config'.
  this.configPath = getConfigPath('config');
  this.loadConfig();
}
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Reads the config file from `configPath` and stores the parsed JSON in
 * `chatsConfig`.
 *
 * The parsed content is deliberately not logged: entries may carry a
 * `token`, and dumping the whole config to stdout would leak it.
 *
 * @throws the underlying `fs` error when the file cannot be read.
 * @throws the original `SyntaxError` (message prefixed with the file path)
 *   when the file is not valid JSON.
 */
private loadConfig(): void {
  const raw = fs.readFileSync(this.configPath, 'utf-8');

  try {
    this.chatsConfig = JSON.parse(raw);
  } catch (error) {
    // Keep the original error object (and type) so existing handlers still
    // match; only add the file path, which JSON.parse cannot know about.
    if (error instanceof Error) {
      error.message = `Failed to parse config file ${this.configPath}: ${error.message}`;
    }
    throw error;
  }
}
Comment on lines
+23
to
+28
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add error handling and consider async file operations. The `loadConfig` method reads and parses the config file synchronously with no error handling: - private loadConfig() {
+ private loadConfig(): void {
+ try {
const file = fs.readFileSync(this.configPath, 'utf-8');
this.chatsConfig = JSON.parse(file);
console.log('Raw file content:', this.chatsConfig);
+ } catch (error) {
+ console.error('Failed to load config:', error);
+ this.chatsConfig = [];
+ throw new Error(`Failed to load configuration: ${error.message}`);
+ }
} 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Looks up a value in the loaded configuration.
 *
 * @param path lodash property path; an empty path selects the whole
 *   configuration object.
 * @returns the selected value, cast to `T` without any runtime check.
 */
get<T>(path: string) {
  const selected = path ? _.get(this.chatsConfig, path) : this.chatsConfig;
  return selected as unknown as T;
}
Comment on lines
+30
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve type safety in get method The generic get method could be made more type-safe by constraining the return type based on the path. - get<T>(path: string) {
+ get<T extends keyof ChatConfig>(path: T): ChatConfig[T];
+ get<T>(path: string): T;
if (!path) {
- return this.chatsConfig as unknown as T;
+ return this.chatsConfig as T;
}
return _.get(this.chatsConfig, path) as T;
}
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Writes `value` at `path` in the in-memory configuration and persists the
 * whole configuration to disk straight away.
 *
 * @param path lodash property path of the value to set.
 * @param value new value; it is serialized with `JSON.stringify` on save.
 */
set(path: string, value: unknown): void {
  _.set(this.chatsConfig, path, value);
  this.saveConfig();
}
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Serializes the in-memory configuration as 4-space-indented JSON and
 * overwrites the config file with it (synchronously).
 */
private saveConfig() {
  const serialized = JSON.stringify(this.chatsConfig, null, 4);
  fs.writeFileSync(this.configPath, serialized, 'utf-8');
}
|
||||||||||||||||||||||||||||||||||||||||||||
/**
 * Checks that the loaded configuration has the shape the rest of the code
 * relies on: an array of entries, each with a non-empty `model` string.
 *
 * `task` is not enforced here because `ChatConfig` declares it optional.
 *
 * @throws Error when the configuration is missing, is not an array, or an
 *   entry lacks a usable `model`.
 */
validateConfig(): void {
  if (!this.chatsConfig) {
    throw new Error("Invalid configuration: 'chats' section is missing.");
  }

  // JSON.parse can yield any JSON value, so the declared type is not a guarantee.
  if (!Array.isArray(this.chatsConfig)) {
    throw new Error("Invalid configuration: 'chats' must be an array.");
  }

  this.chatsConfig.forEach((entry, index) => {
    if (typeof entry?.model !== 'string' || entry.model.length === 0) {
      throw new Error(
        `Invalid configuration: chats[${index}] is missing a 'model' string.`,
      );
    }
  });
}
Comment on lines
+50
to
+54
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion: Enhance configuration validation. The current validation only checks if `chatsConfig` is truthy; it should also verify the shape of each entry: validateConfig() {
if (!this.chatsConfig) {
throw new Error("Invalid configuration: 'chats' section is missing.");
}
+ if (!Array.isArray(this.chatsConfig)) {
+ throw new Error("Invalid configuration: 'chats' must be an array.");
+ }
+ for (const config of this.chatsConfig) {
+ if (!config.model) {
+ throw new Error("Invalid configuration: 'model' is required.");
+ }
+ if (!config.task) {
+ throw new Error("Invalid configuration: 'task' is required.");
+ }
+ }
} 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import path from 'path'; | ||
import * as fs from 'fs'; | ||
import { ConfigLoader } from '../../config/config-loader'; | ||
import { ModelDownloader } from '../model-downloader'; | ||
import { downloadAllModels } from '../utils'; | ||
import { getConfigDir, getConfigPath } from 'src/config/common-path'; | ||
|
||
// Saved so the patched check can delegate to it and so the suite can put it back.
const originalIsArray = Array.isArray;

// Constructor names that the patched check reports as arrays in addition to real arrays.
// NOTE(review): presumably needed so the model runtime accepts typed arrays under Jest — confirm.
const arrayLikeCtorNames = ['Float32Array', 'BigInt64Array'];

Array.isArray = jest.fn((type: any): type is any[] => {
  const ctorName = type?.constructor?.name;
  return arrayLikeCtorNames.includes(ctorName) || originalIsArray(type);
}) as unknown as (arg: any) => arg is any[];
Comment on lines
+8
to
+20
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Avoid modifying global Array.isArray. Modifying the global `Array.isArray` changes its behavior for all code running in the same process; prefer a local helper: - const originalIsArray = Array.isArray;
- Array.isArray = jest.fn((type: any): type is any[] => {
+ const isCustomArray = (type: any): type is any[] => {
if (
type &&
type.constructor &&
(type.constructor.name === 'Float32Array' ||
type.constructor.name === 'BigInt64Array')
) {
return true;
}
- return originalIsArray(type);
+ return Array.isArray(type);
- }) as unknown as (arg: any) => arg is any[];
+ };
🧰 Tools🪛 Biome (1.9.4)[error] 12-13: Change to an optional chain. Unsafe fix: Change to an optional chain. (lint/complexity/useOptionalChain) |
||
|
||
// jest.mock('../../config/config-loader', () => { | ||
// return { | ||
// ConfigLoader: jest.fn().mockImplementation(() => { | ||
// return { | ||
// get: jest.fn().mockReturnValue({ | ||
// chat1: { | ||
// model: 'Felladrin/onnx-flan-alpaca-base', | ||
// task: 'text2text-generation', | ||
// }, | ||
// }), | ||
// validateConfig: jest.fn(), | ||
// }; | ||
// }), | ||
// }; | ||
// }); | ||
|
||
// Integration suite: writes a one-model config, downloads that model, then
// loads it from local files and runs one generation. Both hooks get a
// 10-minute timeout because the download and inference are slow.
describe('loadAllChatsModels with real model loading', () => {
  beforeAll(async () => {
    const testConfig = [
      {
        model: 'Felladrin/onnx-flan-alpaca-base',
        task: 'text2text-generation',
      },
    ];
    // NOTE(review): this overwrites whatever file lives at the shared config
    // path and never restores it — confirm that is acceptable for this suite.
    const configPath = getConfigPath('config');
    fs.writeFileSync(configPath, JSON.stringify(testConfig, null, 2), 'utf8');

    // Fail fast if the file just written does not load and validate.
    new ConfigLoader().validateConfig();
    await downloadAllModels();
  }, 600000);

  it('should load real models specified in config', async () => {
    const downloader = ModelDownloader.getInstance();

    const chat1Model = await downloader.getLocalModel(
      'text2text-generation',
      'Felladrin/onnx-flan-alpaca-base',
    );
    expect(chat1Model).toBeDefined();
    console.log('Loaded Model:', chat1Model);

    expect(chat1Model).toHaveProperty('model');
    expect(chat1Model).toHaveProperty('tokenizer');

    // No try/catch here on purpose: a thrown inference error or a failed
    // expectation must fail the test instead of being logged and ignored.
    const chat1Output = await chat1Model('Write me a love poem about cheese.', {
      max_new_tokens: 200,
      temperature: 0.9,
      repetition_penalty: 2.0,
      no_repeat_ngram_size: 3,
    });

    console.log('Model Output:', chat1Output);
    expect(chat1Output).toBeDefined();
    expect(chat1Output[0]).toHaveProperty('generated_text');
    console.log(chat1Output[0].generated_text);
  }, 600000);
});
|
||
// Once every test has run, put the saved native Array.isArray back in place
// of the Jest mock installed at the top of this file.
afterAll(() => {
  Array.isArray = originalIsArray;
});
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import { Logger } from '@nestjs/common'; | ||
import { PipelineType, pipeline, env } from '@huggingface/transformers'; | ||
import { getModelPath, getModelsDir } from 'src/config/common-path'; | ||
// Module-level setup, run once on import: allow transformers.js to load
// models from disk and use the directory returned by getModelsDir() as the
// local model path.
env.allowLocalModels = true;
env.localModelPath = getModelsDir();
export class ModelDownloader { | ||
readonly logger = new Logger(ModelDownloader.name); | ||
private static instance: ModelDownloader; | ||
/**
 * Returns the shared `ModelDownloader`, constructing it on the first call
 * and handing back the same object on every later call.
 */
public static getInstance(): ModelDownloader {
  ModelDownloader.instance ??= new ModelDownloader();
  return ModelDownloader.instance;
}
Comment on lines
+9
to
+14
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve singleton implementation for thread safety The current singleton implementation might have race conditions in a concurrent environment. Consider using a more robust singleton pattern: - private static instance: ModelDownloader;
- public static getInstance(): ModelDownloader {
- if (!ModelDownloader.instance) {
- ModelDownloader.instance = new ModelDownloader();
- }
- return ModelDownloader.instance;
- }
+ private static instance: ModelDownloader | null = null;
+ private static instanceLock = false;
+ public static getInstance(): ModelDownloader {
+ if (!ModelDownloader.instance) {
+ if (ModelDownloader.instanceLock) {
+ throw new Error('Instance creation in progress');
+ }
+ ModelDownloader.instanceLock = true;
+ try {
+ ModelDownloader.instance = new ModelDownloader();
+ } finally {
+ ModelDownloader.instanceLock = false;
+ }
+ }
+ return ModelDownloader.instance;
+ }
|
||
|
||
/**
 * Builds a transformers.js pipeline for `model`, using the directory from
 * `getModelsDir()` as the cache directory.
 *
 * @param task pipeline task name; cast to `PipelineType` without validation.
 * @param model model identifier passed straight to `pipeline()`.
 * @returns the pipeline instance created by `pipeline()`.
 */
async downloadModel(task: string, model: string): Promise<any> {
  // NOTE(review): presumably pipeline() fetches missing files into cache_dir — confirm.
  const options = { cache_dir: getModelsDir() };
  return await pipeline(task as PipelineType, model, options);
}
|
||
/**
 * Builds a transformers.js pipeline for `model` with `local_files_only`
 * set, requesting the 'local' revision.
 *
 * @param task pipeline task name; cast to `PipelineType` without validation.
 * @param model model identifier passed straight to `pipeline()`.
 * @returns the pipeline instance created by `pipeline()`.
 */
public async getLocalModel(task: string, model: string): Promise<any> {
  const localOnlyOptions = { local_files_only: true, revision: 'local' };
  return await pipeline(task as PipelineType, model, localOnlyOptions);
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import { ChatConfig, ConfigLoader } from 'src/config/config-loader'; | ||
import { ModelDownloader } from './model-downloader'; | ||
|
||
export async function downloadAllModels(): Promise<void> { | ||
const configLoader = new ConfigLoader(); | ||
configLoader.validateConfig(); | ||
const chats = configLoader.get<ChatConfig[]>(''); | ||
const downloader = ModelDownloader.getInstance(); | ||
console.log('Loaded config:', chats); | ||
const loadPromises = chats.map(async (chatConfig: ChatConfig) => { | ||
const { model, task } = chatConfig; | ||
try { | ||
downloader.logger.log(model, task); | ||
const pipelineInstance = await downloader.downloadModel(task, model); | ||
} catch (error) { | ||
downloader.logger.error(`Failed to load model ${model}:`, error.message); | ||
} | ||
}); | ||
Comment on lines
+10
to
+18
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling and type safety improvements Several issues need attention in the model loading logic:
Consider this improved implementation: - const loadPromises = chats.map(async (chatConfig: ChatConfig) => {
+ type SupportedTask = 'text-generation' | 'sentiment-analysis'; // add all supported tasks
+ const loadPromises = chats.map(async (chatConfig: ChatConfig) => {
const { model, task } = chatConfig;
+ if (!isValidTask(task)) {
+ throw new Error(`Unsupported task: ${task}`);
+ }
try {
downloader.logger.log(model, task);
- const pipelineInstance = await downloader.downloadModel(task, model);
+ await downloader.downloadModel(task as SupportedTask, model);
} catch (error) {
downloader.logger.error(`Failed to load model ${model}:`, error.message);
+ throw error; // re-throw to indicate failure
}
});
- await Promise.all(loadPromises);
+ const results = await Promise.allSettled(loadPromises);
+ const failures = results.filter(r => r.status === 'rejected');
+ if (failures.length > 0) {
+ throw new Error(`Failed to load ${failures.length} models`);
+ }
|
||
await Promise.all(loadPromises); | ||
|
||
downloader.logger.log('All models loaded.'); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pin Hugging Face dependencies to specific versions
Using "latest" for production dependencies is risky as it can lead to unexpected breaking changes. Consider pinning to specific versions.