Commit

feat(llm): add max_new_tokens parameter (#10)
Signed-off-by: Lukáš Janeček <[email protected]>
Co-authored-by: Lukáš Janeček <[email protected]>
xjacka and Lukáš Janeček authored Oct 22, 2024
1 parent 850dcbd commit 9b67bea
Showing 1 changed file with 21 additions and 7 deletions.
--- a/src/runs/execution/factory.ts
+++ b/src/runs/execution/factory.ts
@@ -52,6 +52,8 @@ import {
   WATSONX_REGION
 } from '@/config';
 
+const MAX_NEW_TOKENS = 4096;
+
 export function getDefaultModel(backend: LLMBackend = LLM_BACKEND) {
   switch (backend) {
     case LLMBackend.IBM_VLLM:
@@ -95,6 +97,10 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEND) {
             ...parameters.sampling,
             top_p: run.topP ?? parameters.sampling?.top_p,
             temperature: run.temperature ?? parameters.sampling?.temperature
           },
+          stopping: {
+            ...parameters.stopping,
+            max_new_tokens: MAX_NEW_TOKENS
+          }
         })
       });
@@ -106,7 +112,8 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEND) {
         modelId: run.model,
         parameters: {
           top_p: run.topP,
-          temperature: run.temperature
+          temperature: run.temperature,
+          num_predict: MAX_NEW_TOKENS
         }
       });
     }
@@ -117,7 +124,8 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEND) {
         modelId: run.model as OpenAI.ChatModel,
         parameters: {
           top_p: run.topP,
-          temperature: run.temperature
+          temperature: run.temperature,
+          max_completion_tokens: MAX_NEW_TOKENS
         }
       });
     }
@@ -128,7 +136,8 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEND) {
         parameters: (parameters) => ({
           ...parameters,
           top_p: run.topP ?? parameters.top_p,
-          temperature: run.temperature ?? parameters.temperature
+          temperature: run.temperature ?? parameters.temperature,
+          max_new_tokens: MAX_NEW_TOKENS
         })
       });
     }
@@ -142,7 +151,8 @@ export function createChatLLM(run: Loaded<Run>, backend: LLMBackend = LLM_BACKEND) {
         parameters: (parameters) => ({
           ...parameters,
           top_p: run.topP ?? parameters.top_p,
-          temperature: run.temperature ?? parameters.temperature
+          temperature: run.temperature ?? parameters.temperature,
+          max_new_tokens: MAX_NEW_TOKENS
         })
       });
     }
@@ -168,7 +178,10 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) {
       return new IBMvLLM({
         client: vllmClient,
         modelId: 'ibm/granite-34b-code-instruct',
-        parameters: { method: 'GREEDY', stopping: { include_stop_sequence: false } }
+        parameters: {
+          method: 'GREEDY',
+          stopping: { include_stop_sequence: false, max_new_tokens: MAX_NEW_TOKENS }
+        }
       });
     }
     case LLMBackend.BAM: {
@@ -178,7 +191,8 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) {
         modelId: 'ibm/granite-34b-code-instruct',
         parameters: {
           decoding_method: 'greedy',
-          include_stop_sequence: false
+          include_stop_sequence: false,
+          max_new_tokens: MAX_NEW_TOKENS
         }
       });
     }
@@ -193,7 +207,7 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) {
         parameters: {
           decoding_method: 'greedy',
           include_stop_sequence: false,
-          max_new_tokens: 2048
+          max_new_tokens: MAX_NEW_TOKENS
         }
       });
     }
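For context on what the diff wires up: every chat and code backend now receives the same 4096-token generation cap, but each provider spells the knob differently. Below is a minimal standalone TypeScript sketch of that mapping; the backend names and parameter spellings come from the diff above, while the enum string values and the maxTokensParamFor helper are hypothetical, added purely for illustration:

// Illustrative sketch only: mirrors the per-backend parameter names used in
// the commit. This enum and helper are not part of the real factory.ts.
const MAX_NEW_TOKENS = 4096;

enum LLMBackend {
  IBM_VLLM = 'ibm_vllm', // string values are hypothetical
  OLLAMA = 'ollama',
  OPENAI = 'openai',
  BAM = 'bam',
  WATSONX = 'watsonx'
}

// Each provider names the "maximum generated tokens" parameter differently;
// the commit routes the same shared constant to all of them.
function maxTokensParamFor(backend: LLMBackend): Record<string, number> {
  switch (backend) {
    case LLMBackend.IBM_VLLM: // nested under `stopping` in the real code
    case LLMBackend.BAM:
    case LLMBackend.WATSONX:
      return { max_new_tokens: MAX_NEW_TOKENS };
    case LLMBackend.OLLAMA:
      return { num_predict: MAX_NEW_TOKENS };
    case LLMBackend.OPENAI:
      return { max_completion_tokens: MAX_NEW_TOKENS };
  }
}

// e.g. maxTokensParamFor(LLMBackend.OPENAI) -> { max_completion_tokens: 4096 }

One design note grounded in the last hunk: the watsonx code model previously hard-coded max_new_tokens: 2048; switching it to the shared MAX_NEW_TOKENS constant keeps the chat and code paths from drifting apart.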
