From 59e6704cfeddbefffad90169ad0fecd4d53bb1f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Jane=C4=8Dek?= Date: Fri, 18 Oct 2024 14:57:02 +0200 Subject: [PATCH] feat(WatsonX): add region environment variable (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lukáš Janeček Co-authored-by: Lukáš Janeček --- .env.example | 1 + src/config.ts | 1 + src/embedding/factory.ts | 6 ++++-- src/runs/execution/factory.ts | 5 ++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index 6516908..2348a1a 100644 --- a/.env.example +++ b/.env.example @@ -58,6 +58,7 @@ OPENAI_API_KEY= # https://www.ibm.com/products/watsonx-ai WATSONX_API_KEY= WATSONX_PROJECT_ID= +WATSONX_REGION= # optional BAM_API_KEY= diff --git a/src/config.ts b/src/config.ts index 0b38502..3c0d2be 100644 --- a/src/config.ts +++ b/src/config.ts @@ -96,6 +96,7 @@ export const BAM_API_KEY = getEnv('BAM_API_KEY', null); export const WATSONX_API_KEY = getEnv('WATSONX_API_KEY', null); export const WATSONX_PROJECT_ID = getEnv('WATSONX_PROJECT_ID', null); +export const WATSONX_REGION = getEnv('WATSONX_REGION', null); export const CAIKIT_URL = getEnv('CAIKIT_URL', null); export const CAIKIT_CA_CERT = getEnv('CAIKIT_CA_CERT', null); diff --git a/src/embedding/factory.ts b/src/embedding/factory.ts index b245f50..9b5992e 100644 --- a/src/embedding/factory.ts +++ b/src/embedding/factory.ts @@ -34,7 +34,8 @@ import { OLLAMA_URL, OPENAI_API_KEY, WATSONX_API_KEY, - WATSONX_PROJECT_ID + WATSONX_PROJECT_ID, + WATSONX_REGION } from '@/config'; export function getDefaultEmbeddingModel(backend: EmbeddingBackend = EMBEDDING_BACKEND) { @@ -81,7 +82,8 @@ export async function createEmbeddingAdapter( const llm = new WatsonXLLM({ modelId: 'foobar', apiKey: WATSONX_API_KEY, - projectId: WATSONX_PROJECT_ID + projectId: WATSONX_PROJECT_ID, + region: WATSONX_REGION ?? undefined }); // @ts-expect-error use protected property const client = llm.client; diff --git a/src/runs/execution/factory.ts b/src/runs/execution/factory.ts index e6185cd..57efc67 100644 --- a/src/runs/execution/factory.ts +++ b/src/runs/execution/factory.ts @@ -48,7 +48,8 @@ import { OLLAMA_URL, OPENAI_API_KEY, WATSONX_API_KEY, - WATSONX_PROJECT_ID + WATSONX_PROJECT_ID, + WATSONX_REGION } from '@/config'; export function getDefaultModel(backend: LLMBackend = LLM_BACKEND) { @@ -137,6 +138,7 @@ export function createChatLLM(run: Loaded, backend: LLMBackend = LLM_BACKEN return WatsonXChatLLM.fromPreset(run.model as WatsonXChatLLMPresetModel, { apiKey: WATSONX_API_KEY, projectId: WATSONX_PROJECT_ID, + region: WATSONX_REGION ?? undefined, parameters: (parameters) => ({ ...parameters, top_p: run.topP ?? parameters.top_p, @@ -187,6 +189,7 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) { modelId: 'meta-llama/llama-3-1-70b-instruct', apiKey: WATSONX_API_KEY, projectId: WATSONX_PROJECT_ID, + region: WATSONX_REGION ?? undefined, parameters: { decoding_method: 'greedy', include_stop_sequence: false,