feat(backend): Backend config convention and pull model #54

Merged: 25 commits (Dec 16, 2024). Changes shown from 7 commits.

Commits
7635a8a
feat(backend): backend config convention
NarwhalChen Oct 31, 2024
c6fde4a
feat: adding pull model function with test
NarwhalChen Nov 18, 2024
1deca3d
fix: adding default type
NarwhalChen Nov 28, 2024
e69b220
fix: fix the problem of prepare a unready config to other service
NarwhalChen Nov 28, 2024
618c20e
fix: File naming is more semantic
NarwhalChen Nov 30, 2024
f1382eb
Merge remote-tracking branch 'origin/main' into backend-config-conven…
NarwhalChen Dec 10, 2024
96a55d1
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 10, 2024
71b662d
fix: fixing bug in modeldowloader that cannot pull model
NarwhalChen Dec 10, 2024
a4f692d
fix: fixing bug in modeldowloader that cannot pull model
NarwhalChen Dec 10, 2024
458db40
Merge branch 'backend-config-convention' of https://github.com/Sma1lb…
NarwhalChen Dec 10, 2024
bf37046
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 10, 2024
5a6f4bf
to: makes model download when project start and download to local folder
NarwhalChen Dec 12, 2024
ed2e534
to: can load model both locally and remotely
NarwhalChen Dec 13, 2024
3ab6bc8
fix: fixing the layer structure of chatsconfig
NarwhalChen Dec 13, 2024
3fe86b4
fix: fixing the layer structure of chatsconfig and relative bug in Lo…
NarwhalChen Dec 13, 2024
a2c07a6
[autofix.ci] apply automated fixes
autofix-ci[bot] Dec 13, 2024
f97b97d
fix: fixing the layer structure of chatsconfig and relative bug in Lo…
NarwhalChen Dec 15, 2024
eb55281
Merge branch 'backend-config-convention' of https://github.com/Sma1lb…
NarwhalChen Dec 15, 2024
9c34f08
refactor: rename ConfigLoader and ModelLoader files, update imports, …
Sma1lboy Dec 15, 2024
ed07140
fix: updating layer structure of chatconfig and updating relative test
NarwhalChen Dec 15, 2024
8a311f0
Delete .editorconfig
Sma1lboy Dec 16, 2024
28d5f78
Delete pnpm-lock.yaml
Sma1lboy Dec 16, 2024
c1c8a98
Delete pnpm-lock.yaml
Sma1lboy Dec 16, 2024
3ea4222
Merge branch 'main' into backend-config-convention
Sma1lboy Dec 16, 2024
4e59695
Fix merge conflicts in pnpm-lock.yaml and update axios version to 1.7.8
Sma1lboy Dec 16, 2024
2 files renamed without changes.
72 changes: 72 additions & 0 deletions backend/__tests__/loadAllChatsModels.spec.ts
@@ -0,0 +1,72 @@
import { ConfigLoader } from '../src/config/ConfigLoader';
import { ModelDownloader, getModel } from '../src/model/ModelDownloader';

// Patch Array.isArray so typed-array outputs from transformers.js
// (Float32Array, BigInt64Array) pass isArray checks; restored in afterAll.
const originalIsArray = Array.isArray;

Array.isArray = jest.fn((type: any): type is any[] => {
if (
type &&
type.constructor &&
(type.constructor.name === 'Float32Array' ||
type.constructor.name === 'BigInt64Array')
) {
return true;
}
return originalIsArray(type);
}) as unknown as (arg: any) => arg is any[];

jest.mock('../src/config/ConfigLoader', () => {
return {
ConfigLoader: jest.fn().mockImplementation(() => {
return {
get: jest.fn().mockReturnValue({
chat1: {
model: 'Xenova/LaMini-Flan-T5-783M',
task: 'text2text-generation',
},
}),
validateConfig: jest.fn(),
};
}),
};
});

describe('loadAllChatsModels with real model loading', () => {
beforeAll(async () => {
await ModelDownloader.downloadAllModels();
});

it('should load real models specified in config', async () => {
expect(ConfigLoader).toHaveBeenCalled();

const chat1Model = getModel('chat1');
expect(chat1Model).toBeDefined();
console.log('Loaded Model:', chat1Model);

expect(chat1Model).toHaveProperty('model');
expect(chat1Model).toHaveProperty('tokenizer');

try {
const chat1Output = await chat1Model(
'Write me a love poem about cheese.',
{
max_new_tokens: 200,
temperature: 0.9,
repetition_penalty: 2.0,
no_repeat_ngram_size: 3,
},
);

console.log('Model Output:', chat1Output);

expect(chat1Output).toBeDefined();
expect(chat1Output[0]).toHaveProperty('generated_text');
console.log(chat1Output[0].generated_text);

} catch (error) {
console.error('Error during model inference:', error);

}
}, 60000);
});

afterAll(() => {
Array.isArray = originalIsArray;
});
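The `Array.isArray` override at the top of this test is worth isolating: it makes typed arrays produced by transformers.js pass `isArray` checks while delegating everything else to the original. A standalone sketch of that behavior (not the PR's exact code, just the same pattern):

```typescript
// Standalone illustration of the Array.isArray patch used in the test:
// treat Float32Array/BigInt64Array as arrays, delegate all other values.
const originalIsArray = Array.isArray;

const patchedIsArray = ((value: any): boolean => {
  const name = value?.constructor?.name;
  if (name === 'Float32Array' || name === 'BigInt64Array') return true;
  return originalIsArray(value);
}) as unknown as typeof Array.isArray;

console.log(patchedIsArray(new Float32Array(2))); // true (patched)
console.log(originalIsArray(new Float32Array(2))); // false
console.log(patchedIsArray([1, 2, 3])); // true
console.log(patchedIsArray('not an array')); // false
```

Restoring the original in `afterAll`, as the test does, keeps the global patch from leaking into other test files in the same worker.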
4 changes: 4 additions & 0 deletions backend/package.json
@@ -26,6 +26,8 @@
},
"dependencies": {
"@apollo/server": "^4.11.0",
"@huggingface/hub": "latest",
"@huggingface/transformers": "latest",
Comment on lines +29 to +30

⚠️ Potential issue

Pin Hugging Face dependencies to specific versions

Using "latest" for production dependencies is risky as it can lead to unexpected breaking changes. Consider pinning to specific versions.

- "@huggingface/hub": "latest",
- "@huggingface/transformers": "latest",
+ "@huggingface/hub": "^0.14.1",
+ "@huggingface/transformers": "^2.15.0",

Committable suggestion skipped: line range outside the PR's diff.

"@nestjs/apollo": "^12.2.0",
"@nestjs/axios": "^3.0.3",
"@nestjs/common": "^10.0.0",
@@ -36,11 +38,13 @@
"@nestjs/platform-express": "^10.0.0",
"@nestjs/typeorm": "^10.0.2",
"@types/bcrypt": "^5.0.2",
"axios": "^1.7.7",
"@types/fs-extra": "^11.0.4",
"bcrypt": "^5.1.1",
"class-validator": "^0.14.1",
"fs-extra": "^11.2.0",
"graphql": "^16.9.0",
"lodash": "^4.17.21",
"graphql-subscriptions": "^2.0.0",
"graphql-ws": "^5.16.0",
"markdown-to-txt": "^2.0.1",
444 changes: 438 additions & 6 deletions backend/pnpm-lock.yaml

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions backend/src/config/ConfigLoader.ts
@@ -0,0 +1,48 @@
import * as fs from 'fs';
import * as path from 'path';
import _ from 'lodash';
export interface ChatConfig {
model: string;
endpoint?: string;
token?: string;
default?: boolean;
task?: string;
}
export class ConfigLoader {
private config: ChatConfig;

private readonly configPath: string;

constructor() {
this.configPath = path.resolve(__dirname, 'config.json');
this.loadConfig();
}

private loadConfig() {
const file = fs.readFileSync(this.configPath, 'utf-8');
this.config = JSON.parse(file);
}

get<T>(path: string): T {
return _.get(this.config, path) as T;
}

set(path: string, value: any) {
_.set(this.config, path, value);
this.saveConfig();
}

private saveConfig() {
fs.writeFileSync(
this.configPath,
JSON.stringify(this.config, null, 4),
'utf-8',
);
}

validateConfig() {
if (!this.config) {
throw new Error('Invalid configuration: config file is empty or missing.');
}
}
}
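`ConfigLoader.get` delegates dot-path lookup to lodash's `_.get`. The convention can be sketched in plain TypeScript without lodash; the `config` object below is illustrative only (keys like `chats.chat1` are assumptions, not part of the PR):

```typescript
// Minimal sketch of lodash-style dot-path access, as used by ConfigLoader.get().
type ChatConfig = {
  model: string;
  endpoint?: string;
  token?: string;
  default?: boolean;
  task?: string;
};

function getByPath<T>(obj: unknown, path: string): T | undefined {
  // Walk the object one dot-separated segment at a time; bail out on null/undefined.
  return path
    .split('.')
    .reduce<any>((acc, key) => (acc == null ? undefined : acc[key]), obj) as T | undefined;
}

const config = {
  chats: {
    chat1: {
      model: 'Xenova/LaMini-Flan-T5-783M',
      task: 'text2text-generation',
    } as ChatConfig,
  },
};

console.log(getByPath<ChatConfig>(config, 'chats.chat1')?.model); // Xenova/LaMini-Flan-T5-783M
console.log(getByPath<string>(config, 'chats.chat1.task')); // text2text-generation
console.log(getByPath(config, 'chats.missing')); // undefined
```

The real loader additionally persists writes back to `config.json` via `saveConfig()` after every `set`.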
63 changes: 63 additions & 0 deletions backend/src/model/ModelDownloader.ts
@@ -0,0 +1,63 @@
import { Logger } from '@nestjs/common';
import { PipelineType, pipeline, env } from '@huggingface/transformers';
import { ConfigLoader, ChatConfig } from '../config/ConfigLoader';

type Progress = { loaded: number; total: number };
type ProgressCallback = (progress: Progress) => void;

env.allowLocalModels = false;

export class ModelDownloader {
private static readonly logger = new Logger(ModelDownloader.name);
private static readonly loadedModels = new Map<string, any>();

public static async downloadAllModels(
progressCallback: ProgressCallback = () => {},
): Promise<void> {
const configLoader = new ConfigLoader();
configLoader.validateConfig();
const chats = configLoader.get<{ [key: string]: ChatConfig }>('chats');

const loadPromises = Object.entries(chats).map(
async ([chatKey, chatConfig]: [string, ChatConfig]) => {
const { model, task } = chatConfig;
try {
ModelDownloader.logger.log(`Starting to load model: ${model}`);
const pipelineInstance = await ModelDownloader.downloadModel(
task,
model,
progressCallback,
);
ModelDownloader.logger.log(`Model loaded successfully: ${model}`);
this.loadedModels.set(chatKey, pipelineInstance);
} catch (error) {
ModelDownloader.logger.error(
`Failed to load model ${model}:`,
error.message,
);
}
},
);

await Promise.all(loadPromises);

ModelDownloader.logger.log('All models loaded.');
}

private static async downloadModel(
task: string,
model: string,
progressCallback?: ProgressCallback,
): Promise<any> {
const pipelineOptions = progressCallback
? { progress_callback: progressCallback }
: undefined;
return pipeline(task as PipelineType, model, pipelineOptions);
}

public static getModel(chatKey: string): any {
return ModelDownloader.loadedModels.get(chatKey);
}
}

export const getModel = ModelDownloader.getModel;
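The pull-model startup flow introduced here — validate the config, kick off one download per chat entry concurrently, cache each pipeline in a map keyed by chat name — can be sketched without the Hugging Face dependency. `fakePipeline` below is a stand-in for `pipeline()` from `@huggingface/transformers`, which in the real code downloads model weights:

```typescript
// Sketch of ModelDownloader.downloadAllModels(): concurrent per-chat downloads,
// results cached in a map, with per-model failure isolation.
type ChatConfig = { model: string; task?: string };

const loadedModels = new Map<string, unknown>();

async function fakePipeline(task: string, model: string): Promise<{ task: string; model: string }> {
  return { task, model }; // a real pipeline() call would fetch weights here
}

async function downloadAllModels(chats: Record<string, ChatConfig>): Promise<void> {
  // One promise per configured chat; a failed download logs but does not abort the rest.
  await Promise.all(
    Object.entries(chats).map(async ([chatKey, { model, task }]) => {
      try {
        const instance = await fakePipeline(task ?? 'text2text-generation', model);
        loadedModels.set(chatKey, instance);
      } catch (err) {
        console.error(`Failed to load model ${model}:`, err);
      }
    }),
  );
}

const chats = {
  chat1: { model: 'Xenova/LaMini-Flan-T5-783M', task: 'text2text-generation' },
};

downloadAllModels(chats).then(() => {
  console.log(loadedModels.get('chat1'));
  // { task: 'text2text-generation', model: 'Xenova/LaMini-Flan-T5-783M' }
});
```

Because failures are caught per entry, `getModel(chatKey)` can return `undefined` for a model that failed to download — callers such as the test above should check the result is defined before invoking it.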