Skip to content

Commit

Permalink
feat: allow to pass extra data to upstreams endpoints (#470)
Browse files Browse the repository at this point in the history
  • Loading branch information
alekseyvdovenko authored Sep 6, 2024
1 parent 08e4699 commit 9f68c8c
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ Dynamic settings can include the following parameters:
| models.<model_name>.limits | `maxPromptTokens`: maximum number of tokens in a completion request.<br />`maxCompletionTokens`: maximum number of tokens in a completion response.<br />`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br />Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
| models.<model_name>.pricing | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br />`prompt`: per-unit price for the completion request in USD.<br />`completion`: per-unit price for the completion response in USD. |
| models.<model_name>.features | `rateEndpoint`: endpoint for rate requests *(exposed by core as `<deployment name>/rate`)*.<br />`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `<deployment name>/tokenize`)*.<br />`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `<deployment name>/truncate_prompt`)*.<br />`systemPromptSupported`: does the model support system prompt (default is `true`).<br />`toolsSupported`: does the model support tools (default is `false`).<br />`seedSupported`: does the model support `seed` request parameter (default is `false`).<br />`urlAttachmentsSupported`: does the model/application support attachments with URLs (default is `false`).<br />`folderAttachmentsSupported`: does the model/application support folder attachments (default is `false`) |
| models.<model_name>.upstreams | `endpoint`: Model endpoint.<br />`key`: Your API key.<br />`weight`: Weight for upstream endpoint; positive number represents an endpoint capacity, zero or negative disables this enpoint from routing. Default value: 1.<br />`tier`: Specifies tier group for the endpoint. Only positive numbers allowed. All requests will be routed to the endpoints with the highest tier (the lowest tier value), other endpoints (with lower tier/higher tier value) may be used only if the highest tier endpoints are unavailable. Default value: 0 - highest tier. Refer to [Load Balancer](https://docs.epam-rail.com/tutorials/load-balancer) to learn more. |
| models.<model_name>.upstreams | `endpoint`: Model endpoint.<br />`key`: Your API key.<br />`weight`: Weight for upstream endpoint; positive number represents an endpoint capacity, zero or negative disables this enpoint from routing. Default value: 1.<br />`tier`: Specifies tier group for the endpoint. Only positive numbers allowed. All requests will be routed to the endpoints with the highest tier (the lowest tier value), other endpoints (with lower tier/higher tier value) may be used only if the highest tier endpoints are unavailable. Default value: 0 - highest tier. Refer to [Load Balancer](https://docs.epam-rail.com/tutorials/load-balancer) to learn more.<br/>`extraData`: Additional metadata containing any information that is passed to the upstream's endpoint. It can be a JSON or String. |
| models.<model_name>.defaults | Default parameters are applied if a request doesn't contain them in OpenAI `chat/completions` API call |
| models.<model_name>.interceptors | A list of interceptors to be triggered for the given model. Refer to [Interceptors](https://docs.epam-rail.com/tutorials/interceptors) to learn more. |
| keys | API Keys parameters:<br />`<core_key>`: Your API key. Refer to [API Keys](https://github.com/epam/ai-dial/blob/main/docs/Roles%20and%20Access%20Control/3.API%20Keys.md) to learn more. |
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class Proxy implements Handler<HttpServerRequest> {
public static final String HEADER_CONVERSATION_ID = "X-CONVERSATION-ID";
public static final String HEADER_UPSTREAM_ENDPOINT = "X-UPSTREAM-ENDPOINT";
public static final String HEADER_UPSTREAM_KEY = "X-UPSTREAM-KEY";
public static final String HEADER_UPSTREAM_EXTRA_DATA = "X-UPSTREAM-EXTRA-DATA";
public static final String HEADER_UPSTREAM_ATTEMPTS = "X-UPSTREAM-ATTEMPTS";
public static final String HEADER_CONTENT_TYPE_APPLICATION_JSON = "application/json";

Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/epam/aidial/core/config/Upstream.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.epam.aidial.core.config;

import com.epam.aidial.core.util.JsonToStringDeserializer;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
Expand All @@ -13,6 +15,8 @@ public class Upstream {

private String endpoint;
private String key;
@JsonDeserialize(using = JsonToStringDeserializer.class)
private String extraData;
private int weight = 1;
private int tier = 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ void handleProxyRequest(HttpClientRequest proxyRequest) {
Upstream upstream = context.getUpstreamRoute().get();
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_ENDPOINT, upstream.getEndpoint());
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_KEY, upstream.getKey());
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_EXTRA_DATA, upstream.getExtraData());
}

Buffer requestBody = context.getRequestBody();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.epam.aidial.core.util;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;

import java.io.IOException;

public class JsonToStringDeserializer extends JsonDeserializer<String> {

@Override
public String deserialize(JsonParser p, DeserializationContext ctx) throws IOException {
if (p.getCurrentToken() == JsonToken.VALUE_STRING) {
return p.getValueAsString();
}

return p.readValueAsTree().toString();
}
}
48 changes: 24 additions & 24 deletions src/test/java/com/epam/aidial/core/upstream/LoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public class LoadBalancerTest {
@Test
void testWeightedLoadBalancer() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, 1, 0),
new Upstream("endpoint2", null, 9, 0)
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 9, 0)
);
WeightedRoundRobinBalancer balancer = new WeightedRoundRobinBalancer("model1", upstreams);

Expand All @@ -40,10 +40,10 @@ void testWeightedLoadBalancer() {
assertEquals(18, usage.get("endpoint2").getValue());

upstreams = List.of(
new Upstream("endpoint1", null, 1, 0),
new Upstream("endpoint2", null, 1, 0),
new Upstream("endpoint3", null, 1, 0),
new Upstream("endpoint4", null, 1, 0)
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 1, 0),
new Upstream("endpoint3", null, null, 1, 0),
new Upstream("endpoint4", null, null, 1, 0)
);
balancer = new WeightedRoundRobinBalancer("model1", upstreams);

Expand All @@ -66,10 +66,10 @@ void testWeightedLoadBalancer() {
assertEquals(25, usage.get("endpoint4").getValue());

upstreams = List.of(
new Upstream("endpoint1", null, 49, 0),
new Upstream("endpoint2", null, 44, 0),
new Upstream("endpoint3", null, 47, 0),
new Upstream("endpoint4", null, 59, 0)
new Upstream("endpoint1", null, null, 49, 0),
new Upstream("endpoint2", null, null, 44, 0),
new Upstream("endpoint3", null, null, 47, 0),
new Upstream("endpoint4", null, null, 59, 0)
);
balancer = new WeightedRoundRobinBalancer("model1", upstreams);

Expand All @@ -95,8 +95,8 @@ void testWeightedLoadBalancer() {
@Test
void testTieredLoadBalancer() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, 1, 0),
new Upstream("endpoint2", null, 9, 1)
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 9, 1)
);
TieredBalancer balancer = new TieredBalancer("model1", upstreams);

Expand All @@ -111,8 +111,8 @@ void testTieredLoadBalancer() {
@Test
void testLoadBalancerFailure() throws InterruptedException {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, 1, 0),
new Upstream("endpoint2", null, 9, 1)
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 9, 1)
);
TieredBalancer balancer = new TieredBalancer("model1", upstreams);

Expand Down Expand Up @@ -141,8 +141,8 @@ void testLoadBalancerFailure() throws InterruptedException {
@Test
void testZeroWeightLoadBalancer() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, 0, 1),
new Upstream("endpoint2", null, -9, 1)
new Upstream("endpoint1", null, null, 0, 1),
new Upstream("endpoint2", null, null, -9, 1)
);
WeightedRoundRobinBalancer balancer = new WeightedRoundRobinBalancer("model1", upstreams);

Expand All @@ -155,8 +155,8 @@ void testZeroWeightLoadBalancer() {
@Test
void test5xxErrorsHandling() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, 1, 0),
new Upstream("endpoint2", null, 1, 1)
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 1, 1)
);
TieredBalancer balancer = new TieredBalancer("model1", upstreams);

Expand All @@ -183,8 +183,8 @@ void testUpstreamRefresh() {
Model model = new Model();
model.setName("model1");
model.setUpstreams(List.of(
new Upstream("endpoint1", null, 1, 1),
new Upstream("endpoint2", null, 1, 1)
new Upstream("endpoint1", null, null, 1, 1),
new Upstream("endpoint2", null, null, 1, 1)
));

models.put("model1", model);
Expand All @@ -208,8 +208,8 @@ void testUpstreamRefresh() {
Model model1 = new Model();
model1.setName("model1");
model1.setUpstreams(List.of(
new Upstream("endpoint2", null, 1, 1),
new Upstream("endpoint1", null, 1, 1)
new Upstream("endpoint2", null, null, 1, 1),
new Upstream("endpoint1", null, null, 1, 1)
));

models.put("model1", model1);
Expand All @@ -224,8 +224,8 @@ void testUpstreamRefresh() {
Model model2 = new Model();
model2.setName("model1");
model2.setUpstreams(List.of(
new Upstream("endpoint2", null, 5, 1),
new Upstream("endpoint1", null, 1, 1)
new Upstream("endpoint2", null, null, 5, 1),
new Upstream("endpoint1", null, null, 1, 1)
));

models.put("model1", model2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ void testUpstreamRouteWithRetry() {
Model model = new Model();
model.setName("model1");
model.setUpstreams(List.of(
new Upstream("endpoint1", null, 1, 1),
new Upstream("endpoint2", null, 1, 1),
new Upstream("endpoint3", null, 1, 1),
new Upstream("endpoint4", null, 1, 1)
new Upstream("endpoint1", null, null, 1, 1),
new Upstream("endpoint2", null, null, 1, 1),
new Upstream("endpoint3", null, null, 1, 1),
new Upstream("endpoint4", null, null, 1, 1)
));

UpstreamRoute route = upstreamRouteProvider.get(new DeploymentUpstreamProvider(model));
Expand Down Expand Up @@ -76,8 +76,8 @@ void testUpstreamRouteWithRetry2() {
Model model = new Model();
model.setName("model1");
model.setUpstreams(List.of(
new Upstream("endpoint1", null, 1, 1),
new Upstream("endpoint2", null, 1, 1)
new Upstream("endpoint1", null, null, 1, 1),
new Upstream("endpoint2", null, null, 1, 1)
));

UpstreamRoute route = upstreamRouteProvider.get(new DeploymentUpstreamProvider(model));
Expand Down
Loading

0 comments on commit 9f68c8c

Please sign in to comment.