Skip to content

Commit

Permalink
Merge pull request #7 from Sma1lboy/feat-llama-chat
Browse files Browse the repository at this point in the history
feat(chat): Feat llama chat
  • Loading branch information
ZHallen122 authored Oct 22, 2024
2 parents 28ae88a + 7a29b99 commit d7f3716
Show file tree
Hide file tree
Showing 22 changed files with 4,483 additions and 3 deletions.
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
node_modules/
.turbo/
*/**.turbo/
*/**/node_modules
*/**/dist
# temp model
*/**/models
1 change: 1 addition & 0 deletions backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"dependencies": {
"@apollo/server": "^4.11.0",
"@nestjs/apollo": "^12.2.0",
"@nestjs/axios": "^3.0.3",
"@nestjs/common": "^10.0.0",
"@nestjs/config": "^3.2.3",
"@nestjs/core": "^10.0.0",
Expand Down
43 changes: 43 additions & 0 deletions backend/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions backend/src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { TokenModule } from './token/token.module';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { JwtModule } from '@nestjs/jwt';
import { JwtCacheService } from './auth/jwt-cache.service';
import { ChatModule } from './chat/chat.module';

@Module({
imports: [
Expand All @@ -35,6 +36,7 @@ import { JwtCacheService } from './auth/jwt-cache.service';
AuthModule,
ProjectModule,
TokenModule,
ChatModule,
],
providers: [AppService],
})
Expand Down
45 changes: 45 additions & 0 deletions backend/src/chat/chat.model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Field, InputType, ObjectType } from '@nestjs/graphql';

@ObjectType('ChatCompletionDeltaType')
class ChatCompletionDelta {
@Field({ nullable: true })
content?: string;
}
@ObjectType('ChatCompletionChunkType')
export class ChatCompletionChunk {
@Field()
id: string;

@Field()
object: string;

@Field()
created: number;

@Field()
model: string;

@Field({ nullable: true })
system_fingerprint: string | null;

@Field(() => [ChatCompletionChoice])
choices: ChatCompletionChoice[];
}

@ObjectType('ChatCompletionChoiceType')
class ChatCompletionChoice {
@Field()
index: number;

@Field(() => ChatCompletionDelta)
delta: ChatCompletionDelta;

@Field({ nullable: true })
finish_reason: string | null;
}

@InputType('ChatInputType')
export class ChatInput {
@Field()
message: string;
}
12 changes: 12 additions & 0 deletions backend/src/chat/chat.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { GraphQLModule } from '@nestjs/graphql';
import { ApolloDriver, ApolloDriverConfig } from '@nestjs/apollo';
import { HttpModule } from '@nestjs/axios';
import { ChatResolver } from './chat.resolver';
import { ChatProxyService } from './chat.service';

@Module({
imports: [HttpModule],
providers: [ChatResolver, ChatProxyService],
})
export class ChatModule {}
27 changes: 27 additions & 0 deletions backend/src/chat/chat.resolver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { Resolver, Subscription, Args } from '@nestjs/graphql';
import { ChatCompletionChunk, ChatInput } from './chat.model';
import { ChatProxyService } from './chat.service';

@Resolver('Chat')
export class ChatResolver {
constructor(private chatProxyService: ChatProxyService) {}

@Subscription(() => ChatCompletionChunk, {
nullable: true,
resolve: (value) => value,
})
async *chatStream(@Args('input') input: ChatInput) {
const iterator = this.chatProxyService.streamChat(input.message);

try {
for await (const chunk of iterator) {
if (chunk) {
yield chunk;
}
}
} catch (error) {
console.error('Error in chatStream:', error);
throw new Error('Chat stream failed');
}
}
}
134 changes: 134 additions & 0 deletions backend/src/chat/chat.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import { Injectable, Logger } from '@nestjs/common';
import { HttpService } from '@nestjs/axios';
import { ChatCompletionChunk } from './chat.model';

type CustomAsyncIterableIterator<T> = AsyncIterator<T> & {
[Symbol.asyncIterator](): AsyncIterableIterator<T>;
};

@Injectable()
export class ChatProxyService {
private readonly logger = new Logger('ChatProxyService');

constructor(private httpService: HttpService) {}

streamChat(input: string): CustomAsyncIterableIterator<ChatCompletionChunk> {
this.logger.debug('request chat input: ' + input);

let isDone = false;
let responseSubscription: any;
const chunkQueue: ChatCompletionChunk[] = [];
let resolveNextChunk:
| ((value: IteratorResult<ChatCompletionChunk>) => void)
| null = null;

const iterator: CustomAsyncIterableIterator<ChatCompletionChunk> = {
next: () => {
return new Promise<IteratorResult<ChatCompletionChunk>>((resolve) => {
if (chunkQueue.length > 0) {
resolve({ done: false, value: chunkQueue.shift()! });
} else if (isDone) {
resolve({ done: true, value: undefined });
} else {
resolveNextChunk = resolve;
}
});
},
return: () => {
isDone = true;
if (responseSubscription) {
responseSubscription.unsubscribe();
}
return Promise.resolve({ done: true, value: undefined });
},
throw: (error) => {
isDone = true;
if (responseSubscription) {
responseSubscription.unsubscribe();
}
return Promise.reject(error);
},
[Symbol.asyncIterator]() {
return this;
},
};

responseSubscription = this.httpService
.post(
'http://localhost:3001/chat/completion',
{ content: input },
{ responseType: 'stream' },
)
.subscribe({
next: (response) => {
let buffer = '';
response.data.on('data', (chunk: Buffer) => {
buffer += chunk.toString();
let newlineIndex;
while ((newlineIndex = buffer.indexOf('\n')) !== -1) {
const line = buffer.slice(0, newlineIndex).trim();
buffer = buffer.slice(newlineIndex + 1);
if (line.startsWith('data: ')) {
const jsonStr = line.slice(6);
if (jsonStr === '[DONE]') {
isDone = true;
if (resolveNextChunk) {
resolveNextChunk({ done: true, value: undefined });
resolveNextChunk = null;
}
return;
}
try {
const parsedChunk: ChatCompletionChunk = JSON.parse(jsonStr);
if (this.isValidChunk(parsedChunk)) {
if (resolveNextChunk) {
resolveNextChunk({ done: false, value: parsedChunk });
resolveNextChunk = null;
} else {
chunkQueue.push(parsedChunk);
}
} else {
this.logger.warn('Invalid chunk received:', parsedChunk);
}
} catch (error) {
this.logger.error('Error parsing chunk:', error);
}
}
}
});
response.data.on('end', () => {
this.logger.debug('Stream ended');
isDone = true;
if (resolveNextChunk) {
resolveNextChunk({ done: true, value: undefined });
resolveNextChunk = null;
}
});
},
error: (error) => {
this.logger.error('Error in stream:', error);
if (resolveNextChunk) {
resolveNextChunk({ done: true, value: undefined });
resolveNextChunk = null;
}
},
});

return iterator;
}

private isValidChunk(chunk: any): chunk is ChatCompletionChunk {
return (
chunk &&
typeof chunk.id === 'string' &&
typeof chunk.object === 'string' &&
typeof chunk.created === 'number' &&
typeof chunk.model === 'string' &&
Array.isArray(chunk.choices) &&
chunk.choices.length > 0 &&
typeof chunk.choices[0].index === 'number' &&
chunk.choices[0].delta &&
typeof chunk.choices[0].delta.content === 'string'
);
}
}
Loading

0 comments on commit d7f3716

Please sign in to comment.