Skip to content

Commit

Permalink
feat(WatsonX): add region environment variable (#8)
Browse files Browse the repository at this point in the history
Signed-off-by: Lukáš Janeček <[email protected]>
Co-authored-by: Lukáš Janeček <[email protected]>
  • Loading branch information
xjacka and Lukáš Janeček authored Oct 18, 2024
1 parent fcd47f1 commit 59e6704
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ OPENAI_API_KEY=
# https://www.ibm.com/products/watsonx-ai
WATSONX_API_KEY=
WATSONX_PROJECT_ID=
WATSONX_REGION= # optional

BAM_API_KEY=

Expand Down
1 change: 1 addition & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ export const BAM_API_KEY = getEnv('BAM_API_KEY', null);

export const WATSONX_API_KEY = getEnv('WATSONX_API_KEY', null);
export const WATSONX_PROJECT_ID = getEnv('WATSONX_PROJECT_ID', null);
export const WATSONX_REGION = getEnv('WATSONX_REGION', null);

export const CAIKIT_URL = getEnv('CAIKIT_URL', null);
export const CAIKIT_CA_CERT = getEnv('CAIKIT_CA_CERT', null);
Expand Down
6 changes: 4 additions & 2 deletions src/embedding/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ import {
OLLAMA_URL,
OPENAI_API_KEY,
WATSONX_API_KEY,
WATSONX_PROJECT_ID
WATSONX_PROJECT_ID,
WATSONX_REGION
} from '@/config';

export function getDefaultEmbeddingModel(backend: EmbeddingBackend = EMBEDDING_BACKEND) {
Expand Down Expand Up @@ -81,7 +82,8 @@ export async function createEmbeddingAdapter(
const llm = new WatsonXLLM({
modelId: 'foobar',
apiKey: WATSONX_API_KEY,
projectId: WATSONX_PROJECT_ID
projectId: WATSONX_PROJECT_ID,
region: WATSONX_REGION ?? undefined
});
// @ts-expect-error use protected property
const client = llm.client;
Expand Down
5 changes: 4 additions & 1 deletion src/runs/execution/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ import {
OLLAMA_URL,
OPENAI_API_KEY,
WATSONX_API_KEY,
WATSONX_PROJECT_ID
WATSONX_PROJECT_ID,
WATSONX_REGION
} from '@/config';

export function getDefaultModel(backend: LLMBackend = LLM_BACKEND) {
Expand Down Expand Up @@ -137,6 +138,7 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEN
return WatsonXChatLLM.fromPreset(run.model as WatsonXChatLLMPresetModel, {
apiKey: WATSONX_API_KEY,
projectId: WATSONX_PROJECT_ID,
region: WATSONX_REGION ?? undefined,
parameters: (parameters) => ({
...parameters,
top_p: run.topP ?? parameters.top_p,
Expand Down Expand Up @@ -187,6 +189,7 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) {
modelId: 'meta-llama/llama-3-1-70b-instruct',
apiKey: WATSONX_API_KEY,
projectId: WATSONX_PROJECT_ID,
region: WATSONX_REGION ?? undefined,
parameters: {
decoding_method: 'greedy',
include_stop_sequence: false,
Expand Down

0 comments on commit 59e6704

Please sign in to comment.