-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(backend): Backend config convention and pull model (#54)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a configuration management system for chat settings. - Added functionality to download and manage machine learning models. - Implemented a method to download multiple models based on configurations. - **Bug Fixes** - Updated dependencies to enhance functionality and improve data handling. - **Tests** - Added comprehensive tests for loading chat models and verifying their functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jackson Chen <[email protected]> Co-authored-by: Jackson Chen <[email protected]>
- Loading branch information
1 parent
471e9d2
commit cf95728
Showing
11 changed files
with
11,082 additions
and
8,146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import * as fs from 'fs'; | ||
import * as path from 'path'; | ||
import * as _ from 'lodash'; | ||
import { getConfigPath } from './common-path'; | ||
export interface ChatConfig { | ||
model: string; | ||
endpoint?: string; | ||
token?: string; | ||
default?: boolean; | ||
task?: string; | ||
} | ||
|
||
export class ConfigLoader { | ||
private chatsConfig: ChatConfig[]; | ||
|
||
private readonly configPath: string; | ||
|
||
constructor() { | ||
this.configPath = getConfigPath('config'); | ||
this.loadConfig(); | ||
} | ||
|
||
private loadConfig() { | ||
const file = fs.readFileSync(this.configPath, 'utf-8'); | ||
|
||
this.chatsConfig = JSON.parse(file); | ||
console.log('Raw file content:', this.chatsConfig); | ||
} | ||
|
||
get<T>(path: string) { | ||
if (!path) { | ||
return this.chatsConfig as unknown as T; | ||
} | ||
return _.get(this.chatsConfig, path) as T; | ||
} | ||
|
||
set(path: string, value: any) { | ||
_.set(this.chatsConfig, path, value); | ||
this.saveConfig(); | ||
} | ||
|
||
private saveConfig() { | ||
fs.writeFileSync( | ||
this.configPath, | ||
JSON.stringify(this.chatsConfig, null, 4), | ||
'utf-8', | ||
); | ||
} | ||
|
||
validateConfig() { | ||
if (!this.chatsConfig) { | ||
throw new Error("Invalid configuration: 'chats' section is missing."); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
backend/test/app.e2e-spec.ts → backend/src/model/__tests__/app.e2e-spec.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import path from 'path'; | ||
import * as fs from 'fs'; | ||
import { ConfigLoader } from '../../config/config-loader'; | ||
import { ModelDownloader } from '../model-downloader'; | ||
import { downloadAllModels } from '../utils'; | ||
import { getConfigDir, getConfigPath } from 'src/config/common-path'; | ||
|
||
const originalIsArray = Array.isArray; | ||
|
||
Array.isArray = jest.fn((type: any): type is any[] => { | ||
if ( | ||
type && | ||
type.constructor && | ||
(type.constructor.name === 'Float32Array' || | ||
type.constructor.name === 'BigInt64Array') | ||
) { | ||
return true; | ||
} | ||
return originalIsArray(type); | ||
}) as unknown as (arg: any) => arg is any[]; | ||
|
||
// jest.mock('../../config/config-loader', () => { | ||
// return { | ||
// ConfigLoader: jest.fn().mockImplementation(() => { | ||
// return { | ||
// get: jest.fn().mockReturnValue({ | ||
// chat1: { | ||
// model: 'Felladrin/onnx-flan-alpaca-base', | ||
// task: 'text2text-generation', | ||
// }, | ||
// }), | ||
// validateConfig: jest.fn(), | ||
// }; | ||
// }), | ||
// }; | ||
// }); | ||
|
||
describe('loadAllChatsModels with real model loading', () => { | ||
let configLoader: ConfigLoader; | ||
beforeAll(async () => { | ||
const testConfig = [ | ||
{ | ||
model: 'Felladrin/onnx-flan-alpaca-base', | ||
task: 'text2text-generation', | ||
} | ||
]; | ||
const configPath = getConfigPath('config'); | ||
fs.writeFileSync(configPath, JSON.stringify(testConfig, null, 2), 'utf8'); | ||
|
||
configLoader = new ConfigLoader(); | ||
await downloadAllModels(); | ||
}, 600000); | ||
|
||
it('should load real models specified in config', async () => { | ||
const downloader = ModelDownloader.getInstance(); | ||
|
||
const chat1Model = await downloader.getLocalModel( | ||
'text2text-generation', | ||
'Felladrin/onnx-flan-alpaca-base', | ||
); | ||
expect(chat1Model).toBeDefined(); | ||
console.log('Loaded Model:', chat1Model); | ||
|
||
expect(chat1Model).toHaveProperty('model'); | ||
expect(chat1Model).toHaveProperty('tokenizer'); | ||
|
||
try { | ||
const chat1Output = await chat1Model( | ||
'Write me a love poem about cheese.', | ||
{ | ||
max_new_tokens: 200, | ||
temperature: 0.9, | ||
repetition_penalty: 2.0, | ||
no_repeat_ngram_size: 3, | ||
}, | ||
); | ||
|
||
console.log('Model Output:', chat1Output); | ||
expect(chat1Output).toBeDefined(); | ||
expect(chat1Output[0]).toHaveProperty('generated_text'); | ||
console.log(chat1Output[0].generated_text); | ||
} catch (error) { | ||
console.error('Error during model inference:', error); | ||
} | ||
}, 600000); | ||
}); | ||
|
||
afterAll(() => { | ||
Array.isArray = originalIsArray; | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import { Logger } from '@nestjs/common'; | ||
import { PipelineType, pipeline, env } from '@huggingface/transformers'; | ||
import { getModelPath, getModelsDir } from 'src/config/common-path'; | ||
env.allowLocalModels = true; | ||
env.localModelPath = getModelsDir(); | ||
export class ModelDownloader { | ||
readonly logger = new Logger(ModelDownloader.name); | ||
private static instance: ModelDownloader; | ||
public static getInstance(): ModelDownloader { | ||
if (!ModelDownloader.instance) { | ||
ModelDownloader.instance = new ModelDownloader(); | ||
} | ||
return ModelDownloader.instance; | ||
} | ||
|
||
async downloadModel(task: string, model: string): Promise<any> { | ||
const pipelineInstance = await pipeline(task as PipelineType, model, { | ||
cache_dir: getModelsDir(), | ||
}); | ||
return pipelineInstance; | ||
} | ||
|
||
public async getLocalModel(task: string, model: string): Promise<any> { | ||
const pipelineInstance = await pipeline(task as PipelineType, model, { | ||
local_files_only: true, | ||
revision: 'local', | ||
}); | ||
|
||
return pipelineInstance; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import { ChatConfig, ConfigLoader } from 'src/config/config-loader'; | ||
import { ModelDownloader } from './model-downloader'; | ||
|
||
export async function downloadAllModels(): Promise<void> { | ||
const configLoader = new ConfigLoader(); | ||
configLoader.validateConfig(); | ||
const chats = configLoader.get<ChatConfig[]>(''); | ||
const downloader = ModelDownloader.getInstance(); | ||
console.log('Loaded config:', chats); | ||
const loadPromises = chats.map(async (chatConfig: ChatConfig) => { | ||
const { model, task } = chatConfig; | ||
try { | ||
downloader.logger.log(model, task); | ||
const pipelineInstance = await downloader.downloadModel(task, model); | ||
} catch (error) { | ||
downloader.logger.error(`Failed to load model ${model}:`, error.message); | ||
} | ||
}); | ||
await Promise.all(loadPromises); | ||
|
||
downloader.logger.log('All models loaded.'); | ||
} |
Oops, something went wrong.