feat: Provide typed model parameters in LLM Module Config (#424)
* make model_params optional

* changeset

* add convenience function

* public api

* fix: Changes from lint

* type def and update readme

* remove comment

* remove all cases of empty model_params

* fix type test

* conflict fix

* changelog text

* remove empty model_param from Readme

---------

Co-authored-by: cloud-sdk-js <[email protected]>
Co-authored-by: Deeksha Sinha <[email protected]>
3 people authored Jan 10, 2025
1 parent 1da2caa commit 1476584
Showing 14 changed files with 58 additions and 45 deletions.
5 changes: 5 additions & 0 deletions .changeset/sharp-pianos-press.md
@@ -0,0 +1,5 @@
+---
+'@sap-ai-sdk/orchestration': minor
+---
+
+[Improvement] Refine the type definition of the `model_params` property in the `LlmModuleConfig` to also include known properties.
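The refined typing can be sketched with a small self-contained example. This is a simplified approximation based on the type definitions in this diff, not the actual SDK source — the real `LlmModuleConfig` also intersects the generated client type and narrows `model_name` to the `ChatModel` union:

```typescript
// Simplified sketch of the refined types (not the actual SDK source).
// Known parameters get concrete types; arbitrary extra parameters stay allowed.
type LlmModelParams = {
  max_tokens?: number;
  temperature?: number;
  frequency_penalty?: number;
  presence_penalty?: number;
  top_p?: number;
  n?: number;
} & Record<string, any>;

type LlmModuleConfig = {
  model_name: string; // the SDK narrows this to the ChatModel union
  model_params?: LlmModelParams; // now optional
  model_version?: string; // defaults to 'latest'
};

// model_params can now be omitted entirely...
const minimal: LlmModuleConfig = { model_name: 'gpt-4o' };

// ...or mix typed and arbitrary parameters.
const tuned: LlmModuleConfig = {
  model_name: 'gpt-4o',
  model_params: { max_tokens: 50, temperature: 0.2, custom_param: 'value' }
};

console.log(minimal.model_params === undefined); // true
console.log(tuned.model_params?.max_tokens); // 50
```

Because `LlmModelParams` intersects the known properties with `Record<string, any>`, known parameters are type-checked (e.g. `max_tokens` must be a number) while unknown model-specific parameters still pass through unchanged.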
12 changes: 5 additions & 7 deletions packages/orchestration/README.md
@@ -64,9 +64,10 @@ Consequently, each orchestration deployment uniquely maps to a resource group wi
 ## Usage
 
 Leverage the orchestration service capabilities by using the orchestration client.
-Configure the LLM module by setting the `model_name` and `model_params` properties.
+Configure the LLM module by setting the `model_name` property.
 Define the optional `model_version` property to choose an available model version.
 By default, the version is set to `latest`.
+Specify the optional `model_params` property to apply specific parameters to the model.
 
 ```ts
 import { OrchestrationClient } from '@sap-ai-sdk/orchestration';
@@ -203,8 +204,7 @@ import { OrchestrationClient } from '@sap-ai-sdk/orchestration';

 const orchestrationClient = new OrchestrationClient({
   llm: {
-    model_name: 'gpt-4o',
-    model_params: {}
+    model_name: 'gpt-4o'
   },
   templating: {
     template: [
@@ -299,8 +299,7 @@ You can anonymize or pseudonymize the prompt using the data masking capabilities
 ```ts
 const orchestrationClient = new OrchestrationClient({
   llm: {
-    model_name: 'gpt-4o',
-    model_params: {}
+    model_name: 'gpt-4o'
   },
   templating: {
     template: [
@@ -335,8 +334,7 @@ Grounding enables integrating external, contextually relevant, domain-specific,
 ```ts
 const orchestrationClient = new OrchestrationClient({
   llm: {
-    model_name: 'gpt-35-turbo',
-    model_params: {}
+    model_name: 'gpt-35-turbo'
   },
   templating: {
     template: [
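Taken together, the README changes above describe a client configuration like the following sketch. Only the plain config object is shown here so the snippet stays self-contained; the `OrchestrationClient` call and deployment setup from the README are omitted, and the parameter values are illustrative:

```typescript
// Orchestration module config per the updated README: model_params and
// model_version are optional, and the version defaults to 'latest'.
const config = {
  llm: {
    model_name: 'gpt-4o',
    model_version: 'latest', // optional; 'latest' is already the default
    model_params: { max_tokens: 256, temperature: 0.1 } // optional tuning
  },
  templating: {
    template: [{ role: 'user', content: 'Hello, world!' }]
  }
};

console.log(config.llm.model_params.temperature); // 0.1
```

In the SDK this object would be passed to `new OrchestrationClient(config)`; the point of this change is that the `model_params` line can simply be deleted when no tuning is needed.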
@@ -23,7 +23,7 @@ export type LlmModuleConfig = {
    * "n": 2
    * }
    */
-  model_params: Record<string, any>;
+  model_params?: Record<string, any>;
   /**
    * Version of the model to use
    * Default: "latest".
5 changes: 3 additions & 2 deletions packages/orchestration/src/index.ts
@@ -41,12 +41,13 @@ export type {
 export type {
   OrchestrationModuleConfig,
   LlmModuleConfig,
-  Prompt
+  Prompt,
+  LlmModelParams
 } from './orchestration-types.js';
 
 export { OrchestrationClient } from './orchestration-client.js';
 
-export { buildAzureContentFilter } from './orchestration-filter-utility.js';
+export { buildAzureContentFilter } from './orchestration-utils.js';
 
 export { OrchestrationResponse } from './orchestration-response.js';
2 changes: 1 addition & 1 deletion packages/orchestration/src/internal.ts
@@ -1,4 +1,4 @@
 export * from './orchestration-client.js';
-export * from './orchestration-filter-utility.js';
+export * from './orchestration-utils.js';
 export * from './orchestration-types.js';
 export * from './orchestration-response.js';
8 changes: 3 additions & 5 deletions packages/orchestration/src/orchestration-client.test.ts
@@ -10,7 +10,7 @@ import {
   constructCompletionPostRequest,
   OrchestrationClient
 } from './orchestration-client.js';
-import { buildAzureContentFilter } from './orchestration-filter-utility.js';
+import { buildAzureContentFilter } from './orchestration-utils.js';
 import { OrchestrationResponse } from './orchestration-response.js';
 import type { CompletionPostResponse } from './client/api/schema/index.js';
 import type {
@@ -268,8 +268,7 @@ describe('orchestration service client', () => {
   it('calls chatCompletion with grounding configuration', async () => {
     const config: OrchestrationModuleConfig = {
       llm: {
-        model_name: 'gpt-35-turbo',
-        model_params: {}
+        model_name: 'gpt-35-turbo'
       },
       templating: {
         template: [
@@ -337,8 +336,7 @@ describe('orchestration service client', () => {

     const config: OrchestrationModuleConfig = {
       llm: {
-        model_name: 'gpt-4o',
-        model_params: {}
+        model_name: 'gpt-4o'
       },
       templating: {
         template: [{ role: 'user', content: "What's my name?" }]
@@ -1,5 +1,5 @@
 import { constructCompletionPostRequest } from './orchestration-client.js';
-import { buildAzureContentFilter } from './orchestration-filter-utility.js';
+import { buildAzureContentFilter } from './orchestration-utils.js';
 import type { CompletionPostRequest } from './client/api/schema';
 import type { OrchestrationModuleConfig } from './orchestration-types.js';
13 changes: 13 additions & 0 deletions packages/orchestration/src/orchestration-types.ts
@@ -29,8 +29,21 @@ export interface Prompt {
 export type LlmModuleConfig = OriginalLlmModuleConfig & {
   /** */
   model_name: ChatModel;
+  model_params?: LlmModelParams;
 };
 
+/**
+ * Model Parameters for LLM module configuration.
+ */
+export type LlmModelParams = {
+  max_tokens?: number;
+  temperature?: number;
+  frequency_penalty?: number;
+  presence_penalty?: number;
+  top_p?: number;
+  n?: number;
+} & Record<string, any>;
+
 /**
  * Orchestration module configuration.
  */
@@ -1,9 +1,9 @@
 import { constructCompletionPostRequest } from './orchestration-client.js';
-import { buildAzureContentFilter } from './orchestration-filter-utility.js';
+import { buildAzureContentFilter } from './orchestration-utils.js';
 import type {
   CompletionPostRequest,
   FilteringModuleConfig
-} from './client/api/schema';
+} from './client/api/schema/index.js';
 import type { OrchestrationModuleConfig } from './orchestration-types.js';
 
 describe('filter utility', () => {
1 change: 0 additions & 1 deletion packages/orchestration/src/spec/api.yaml
@@ -270,7 +270,6 @@ components:
       type: object
       required:
         - model_name
-        - model_params
       additionalProperties: false
       properties:
         model_name:
3 changes: 1 addition & 2 deletions sample-cap/srv/orchestration/orchestration-service.ts
@@ -4,8 +4,7 @@ export default class OrchestrationService {
   async chatCompletion(req: any) {
     const { template, inputParams } = req.data;
     const llm = {
-      model_name: 'gpt-4-32k',
-      model_params: {}
+      model_name: 'gpt-4-32k'
     };
     const templating = { template };
9 changes: 3 additions & 6 deletions sample-code/src/orchestration.ts
@@ -22,8 +22,7 @@ export async function orchestrationChatCompletion(): Promise<OrchestrationRespon
   const orchestrationClient = new OrchestrationClient({
     // define the language model to be used
     llm: {
-      model_name: 'gpt-4o',
-      model_params: {}
+      model_name: 'gpt-4o'
     },
     // define the prompt
     templating: {
@@ -41,8 +40,7 @@ }
 }
 
 const llm: LlmModuleConfig = {
-  model_name: 'gpt-4o',
-  model_params: {}
+  model_name: 'gpt-4o'
 };
 
 /**
@@ -160,8 +158,7 @@ export async function orchestrationCompletionMasking(): Promise<
 > {
   const orchestrationClient = new OrchestrationClient({
     llm: {
-      model_name: 'gpt-4-32k',
-      model_params: {}
+      model_name: 'gpt-4-32k'
     },
     templating: {
       template: [
37 changes: 20 additions & 17 deletions tests/type-tests/test/orchestration.test-d.ts
@@ -1,10 +1,11 @@
-import { expectError, expectType } from 'tsd';
+import { expectError, expectType, expectAssignable } from 'tsd';
 import {
   OrchestrationClient,
   CompletionPostResponse,
   OrchestrationResponse,
   TokenUsage,
-  ChatModel
+  ChatModel,
+  LlmModelParams
 } from '@sap-ai-sdk/orchestration';
 
 /**
@@ -16,8 +17,7 @@ expectType<Promise<OrchestrationResponse>>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 );
@@ -29,8 +29,7 @@ expectType<CompletionPostResponse>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 ).data
@@ -43,8 +42,7 @@ expectType<string | undefined>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 ).getContent()
@@ -57,8 +55,7 @@ expectType<string | undefined>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 ).getFinishReason()
@@ -71,8 +68,7 @@ expectType<TokenUsage>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 ).getTokenUsage()
@@ -85,8 +81,7 @@ expectType<Promise<OrchestrationResponse>>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   },
   {
@@ -197,8 +192,7 @@ expectError<any>(new OrchestrationClient({}).chatCompletion());
 expectError<any>(
   new OrchestrationClient({
     llm: {
-      model_name: 'gpt-35-turbo-16k',
-      model_params: {}
+      model_name: 'gpt-35-turbo-16k'
     }
   }).chatCompletion()
 );
@@ -212,11 +206,20 @@ expectError<any>(
       template: [{ role: 'user', content: 'Hello!' }]
     },
     llm: {
-      model_params: {}
+      model_params: { max_tokens: 50 }
     }
   }).chatCompletion()
 );
 
+/**
+ * Model parameters should accept known typed parameters and arbitrary parameters.
+ */
+expectAssignable<LlmModelParams>({
+  max_tokens: 50,
+  temperature: 0.2,
+  random_property: 'random - value'
+});
+
 /**
  * Model parameters should adhere to OrchestrationCompletionParameters.// Todo: Check if additional checks can be added for model_params.
  */
