Commit

resolve PR comments
jezekra1 committed Jan 6, 2025
1 parent cf288f8 commit 0eaa1a2
Showing 12 changed files with 64 additions and 32 deletions.
4 changes: 2 additions & 2 deletions seeders/DatabaseSeeder.ts
@@ -92,7 +92,7 @@ export class DatabaseSeeder extends Seeder {
redactedValue: redactProjectKeyValue(PROJECT_API_KEY)
});
const beeAssistant = new Assistant({
- model: defaultAIProvider.createChatLLM().modelId,
+ model: defaultAIProvider.createAssistantBackend().modelId,
agent: Agent.BEE,
tools: [
{
@@ -121,7 +121,7 @@ export class DatabaseSeeder extends Seeder {
}
});
const streamlitAssistant = new Assistant({
- model: defaultAIProvider.createChatLLM().modelId,
+ model: defaultAIProvider.createAssistantBackend().modelId,
agent: Agent.STREAMLIT,
tools: [],
name: 'Builder Assistant',
2 changes: 1 addition & 1 deletion src/assistants/assistants.service.ts
@@ -114,7 +114,7 @@ export async function createAssistant({
description: description ?? undefined,
metadata,
topP: top_p ?? undefined,
- model: model ?? defaultAIProvider.createChatLLM().modelId,
+ model: model ?? defaultAIProvider.createAssistantBackend().modelId,
agent,
temperature: temperature ?? undefined,
systemPromptOverwrite: system_prompt_overwrite ?? undefined
2 changes: 1 addition & 1 deletion src/chat/chat.service.ts
@@ -55,7 +55,7 @@ export async function createChatCompletion({
messages,
response_format
}: ChatCompletionCreateBody): Promise<ChatCompletionCreateResponse> {
- const llm = defaultAIProvider.createChatLLM({ model });
+ const llm = defaultAIProvider.createChatBackend({ model });
const chat = new Chat({ model: llm.modelId, messages, responseFormat: response_format });
await ORM.em.persistAndFlush(chat);
try {
2 changes: 1 addition & 1 deletion src/runs/execution/execute.ts
@@ -122,7 +122,7 @@ export async function executeRun(run: LoadedRun) {
const context = { run, publish } as AgentContext;

const tools = await getTools(run, context);
- const llm = defaultAIProvider.createChatLLM(run);
+ const llm = defaultAIProvider.createAssistantBackend(run);
const memory = new TokenMemory({ llm });
await memory.addMany(messages);

66 changes: 48 additions & 18 deletions src/runs/execution/provider.ts
@@ -50,9 +50,10 @@ interface AIProvider<
ChatLLMType extends ChatLLM<ChatLLMOutput>,
LLMType extends LLM<BaseLLMOutput> = any
> {
- createChatLLM: (params?: ChatLLMParams) => ChatLLMType;
- createCodeLLM: (params?: { model?: string }) => LLMType | void;
- createEmbeddingModel?: (params?: { model?: string }) => EmbeddingModel;
+ createChatBackend: (params?: ChatLLMParams) => ChatLLMType;
+ createAssistantBackend: (params?: ChatLLMParams) => ChatLLMType;
+ createCodeBackend: (params?: { model?: string }) => LLMType | void;
+ createEmbeddingBackend?: (params?: { model?: string }) => EmbeddingModel;
}

export class BamAIProvider implements AIProvider<BAMChatLLM, BAMLLM> {
@@ -62,7 +63,10 @@ export class BamAIProvider implements AIProvider<BAMChatLLM, BAMLLM> {
BamAIProvider.client ??= new BAMClient({ apiKey: BAM_API_KEY ?? undefined });
}

- createChatLLM({ model = 'meta-llama/llama-3-1-70b-instruct', ...params }: ChatLLMParams = {}) {
+ createChatBackend({
+   model = 'meta-llama/llama-3-1-70b-instruct',
+   ...params
+ }: ChatLLMParams = {}) {
return BAMChatLLM.fromPreset(model as BAMChatLLMPresetModel, {
client: BamAIProvider.client,
parameters: (parameters) => ({
@@ -74,7 +78,11 @@ export class BamAIProvider implements AIProvider<BAMChatLLM, BAMLLM> {
});
}

- createCodeLLM({ model = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
+ createAssistantBackend(params?: ChatLLMParams) {
+   return this.createChatBackend(params);
+ }
+
+ createCodeBackend({ model = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
return new BAMLLM({
client: BamAIProvider.client,
modelId: model,
@@ -86,7 +94,7 @@ export class BamAIProvider implements AIProvider<BAMChatLLM, BAMLLM> {
});
}

- createEmbeddingModel({ model = 'baai/bge-large-en-v1.5' } = {}) {
+ createEmbeddingBackend({ model = 'baai/bge-large-en-v1.5' } = {}) {
return new BAMLLM({ client: BamAIProvider.client, modelId: model });
}
}
@@ -98,7 +106,7 @@ export class OllamaAIProvider implements AIProvider<OllamaChatLLM, OllamaLLM> {
OllamaAIProvider.client ??= new Ollama({ host: OLLAMA_URL ?? undefined });
}

- createChatLLM({ model: modelId = 'llama3.1', ...params }: ChatLLMParams = {}) {
+ createChatBackend({ model: modelId = 'llama3.1', ...params }: ChatLLMParams = {}) {
return new OllamaChatLLM({
client: OllamaAIProvider.client,
modelId,
@@ -109,9 +117,13 @@ export class OllamaAIProvider implements AIProvider<OllamaChatLLM, OllamaLLM> {
}
});
}
- createCodeLLM() {}
+ createAssistantBackend(params?: ChatLLMParams) {
+   return this.createChatBackend(params);
+ }

- createEmbeddingModel({ model: modelId = 'nomic-embed-text' } = {}) {
+ createCodeBackend() {}
+
+ createEmbeddingBackend({ model: modelId = 'nomic-embed-text' } = {}) {
return new OllamaLLM({ client: OllamaAIProvider.client, modelId });
}
}
@@ -122,7 +134,7 @@ export class OpenAIProvider implements AIProvider<OpenAIChatLLM> {
OpenAIProvider.client ??= new OpenAI({ apiKey: OPENAI_API_KEY ?? undefined });
}

- createChatLLM({ model = 'gpt-4o', ...params }: ChatLLMParams = {}) {
+ createChatBackend({ model = 'gpt-4o', ...params }: ChatLLMParams = {}) {
return new OpenAIChatLLM({
client: OpenAIProvider.client,
modelId: model as OpenAI.ChatModel,
@@ -134,9 +146,13 @@ export class OpenAIProvider implements AIProvider<OpenAIChatLLM> {
});
}

- createCodeLLM() {}
+ createAssistantBackend(params?: ChatLLMParams) {
+   return this.createChatBackend(params);
+ }
+
+ createCodeBackend() {}

- createEmbeddingModel({ model = 'text-embedding-3-large' } = {}) {
+ createEmbeddingBackend({ model = 'text-embedding-3-large' } = {}) {
return {
chatLLM: new OpenAIChatLLM({
client: OpenAIProvider.client,
@@ -170,7 +186,10 @@ export class IBMvLLMAIProvider implements AIProvider<IBMVllmChatLLM, IBMvLLM> {
});
}

- createChatLLM({ model = IBMVllmModel.LLAMA_3_1_70B_INSTRUCT, ...params }: ChatLLMParams = {}) {
+ createChatBackend({
+   model = IBMVllmModel.LLAMA_3_1_70B_INSTRUCT,
+   ...params
+ }: ChatLLMParams = {}) {
return IBMVllmChatLLM.fromPreset(model as IBMVllmChatLLMPresetModel, {
client: IBMvLLMAIProvider.client,
parameters: (parameters) => ({
@@ -188,7 +207,11 @@ export class IBMvLLMAIProvider implements AIProvider<IBMVllmChatLLM, IBMvLLM> {
});
}

- createCodeLLM({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
+ createAssistantBackend(params?: ChatLLMParams) {
+   return this.createChatBackend(params);
+ }
+
+ createCodeBackend({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
return new IBMvLLM({
client: IBMvLLMAIProvider.client,
modelId,
@@ -199,7 +222,7 @@ export class IBMvLLMAIProvider implements AIProvider<IBMVllmChatLLM, IBMvLLM> {
});
}

- createEmbeddingModel({ model: modelId = 'baai/bge-large-en-v1.5' } = {}) {
+ createEmbeddingBackend({ model: modelId = 'baai/bge-large-en-v1.5' } = {}) {
return new IBMvLLM({ client: IBMvLLMAIProvider.client, modelId });
}
}
@@ -215,7 +238,10 @@ export class WatsonxAIProvider implements AIProvider<WatsonXChatLLM, WatsonXLLM>
};
}

- createChatLLM({ model = 'meta-llama/llama-3-1-70b-instruct', ...params }: ChatLLMParams = {}) {
+ createChatBackend({
+   model = 'meta-llama/llama-3-1-70b-instruct',
+   ...params
+ }: ChatLLMParams = {}) {
return WatsonXChatLLM.fromPreset(model as WatsonXChatLLMPresetModel, {
...this.credentials,
parameters: (parameters) => ({
@@ -227,7 +253,11 @@ export class WatsonxAIProvider implements AIProvider<WatsonXChatLLM, WatsonXLLM>
});
}

- createCodeLLM({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
+ createAssistantBackend(params?: ChatLLMParams) {
+   return this.createChatBackend(params);
+ }
+
+ createCodeBackend({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) {
return new WatsonXLLM({
...this.credentials,
modelId,
@@ -239,7 +269,7 @@ export class WatsonxAIProvider implements AIProvider<WatsonXChatLLM, WatsonXLLM>
});
}

- createEmbeddingModel({ model: modelId = 'ibm/slate-30m-english-rtrvr-v2' } = {}) {
+ createEmbeddingBackend({ model: modelId = 'ibm/slate-30m-english-rtrvr-v2' } = {}) {
return new WatsonXLLM({ ...this.credentials, modelId, region: WATSONX_REGION ?? undefined });
}
}
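
Taken together, the provider hunks above rename the factory methods and add a dedicated assistant entry point that, for now, simply delegates to the chat backend. As a minimal sketch, the call sites after this commit look like the following (illustrative only, assembled from the hunks elsewhere in this diff; `model` and `run` stand in for the values each caller already has):

// Illustrative call sites after the rename (not part of the diff itself):
const chatLLM = defaultAIProvider.createChatBackend({ model });   // chat completions (chat.service.ts)
const runLLM = defaultAIProvider.createAssistantBackend(run);     // assistant run execution (execute.ts)
const codeLLM = defaultAIProvider.createCodeBackend();            // may return void, e.g. Ollama/OpenAI (helpers.ts)
const embedder = defaultAIProvider.createEmbeddingBackend();      // embeddings (file-search-tool.ts)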
6 changes: 4 additions & 2 deletions src/runs/execution/tools/file-search-tool.ts
@@ -94,9 +94,11 @@ export class FileSearchTool extends Tool<FileSearchToolOutput, FileSearchToolOpt
): Promise<FileSearchToolOutput> {
const vectorStoreClient = getVectorStoreClient();

- const embeddingModel = defaultAIProvider.createEmbeddingModel();
+ const embeddingModel = defaultAIProvider.createEmbeddingBackend();

- const [embedding] = (await embeddingModel.embed([query], { signal: run.signal })).embeddings;
+ const {
+   embeddings: [embedding]
+ } = await embeddingModel.embed([query], { signal: run.signal });

if (this.vectorStores.some((vectorStore) => vectorStore.expired)) {
throw new Error('Some of the vector stores are expired');
4 changes: 2 additions & 2 deletions src/runs/execution/tools/helpers.ts
@@ -134,7 +134,7 @@ export async function getTools(run: LoadedRun, context: AgentContext): Promise<F
(tool): tool is SystemUsage => tool.type === ToolType.SYSTEM && tool.toolId === SystemTools.LLM
);
if (llmUsage) {
- tools.push(new LLMTool({ llm: defaultAIProvider.createChatLLM() }));
+ tools.push(new LLMTool({ llm: defaultAIProvider.createChatBackend() }));
}

const calculatorUsage = run.tools.find(
@@ -186,7 +186,7 @@ export async function getTools(run: LoadedRun, context: AgentContext): Promise<F
.flatMap((container) => container.file.$);

if (codeInterpreterUsage) {
- const codeLLM = defaultAIProvider.createCodeLLM();
+ const codeLLM = defaultAIProvider.createCodeBackend();
tools.push(
new PythonTool({
codeInterpreter,
2 changes: 1 addition & 1 deletion src/runs/execution/tools/wikipedia-tool.ts
@@ -36,7 +36,7 @@ export function wikipediaTool(
maxResults = 5
): AnyTool {
// LLM to perform text embedding
- const embeddingLLM = defaultAIProvider.createEmbeddingModel();
+ const embeddingLLM = defaultAIProvider.createEmbeddingBackend();

// Similarity tool to calculate the similarity between a query and a set of wikipedia passages
const similarity = new SimilarityTool({
2 changes: 1 addition & 1 deletion src/tools/tools.service.ts
@@ -465,7 +465,7 @@ function getSystemTools() {
});
const fileSearch = new FileSearchTool({ vectorStores: [], maxNumResults: 0 });
const readFile = new ReadFileTool({ files: [], fileSize: 0 });
- const llmTool = new LLMTool({ llm: defaultAIProvider.createChatLLM() });
+ const llmTool = new LLMTool({ llm: defaultAIProvider.createChatBackend() });
const calculatorTool = new CalculatorTool();

const systemTools = new Map<string, SystemTool>();
2 changes: 1 addition & 1 deletion src/users/users.service.ts
@@ -131,7 +131,7 @@ export async function createUser({
user.defaultProject = ORM.em.getRepository(Project).getReference(project.id, { wrapped: true });

const assistant = new Assistant({
- model: defaultAIProvider.createChatLLM().modelId,
+ model: defaultAIProvider.createAssistantBackend().modelId,
agent: Agent.BEE,
tools: [
new SystemUsage({ toolId: SystemTools.WEB_SEARCH }),
2 changes: 1 addition & 1 deletion src/vector-store-files/execution/client.ts
@@ -28,7 +28,7 @@ export type DocumentType = z.ZodType<Document>;

export function getVectorStoreClient(): VectorStoreClient<DocumentType> {
return new MilvusVectorStore({
- modelName: defaultAIProvider.createEmbeddingModel().modelId,
+ modelName: defaultAIProvider.createEmbeddingBackend().modelId,
documentSchema: DocumentSchema
});
}
2 changes: 1 addition & 1 deletion src/vector-store-files/execution/process-file.ts
@@ -64,7 +64,7 @@ export async function processVectorStoreFile(vectorStoreFile: Loaded<VectorStore
}

async function* embedTransform(source: AsyncIterable<string[]>) {
- const embeddingAdapter = defaultAIProvider.createEmbeddingModel();
+ const embeddingAdapter = defaultAIProvider.createEmbeddingBackend();
for await (const items of source) {
const output = await embeddingAdapter.embed(items, { signal: controller.signal });
yield output.embeddings.map((embedding, idx) => ({