Skip to content

Commit

Permalink
feat: exposed rate feature in the listings (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Nov 30, 2023
1 parent 0f08f43 commit 2a84361
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 232 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/epam/aidial/core/config/Assistants.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
@Data
public class Assistants {
private String endpoint;
private String rateEndpoint;
private Features features;
private Map<String, Assistant> assistants = Map.of();
}
31 changes: 31 additions & 0 deletions src/main/java/com/epam/aidial/core/config/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class Config {
public static final String ASSISTANT = "assistant";

// maintain the order of routes defined in the config
private LinkedHashMap<String, Route> routes = new LinkedHashMap<>();
private Map<String, Model> models = Map.of();
Expand All @@ -17,4 +19,33 @@ public class Config {
private Assistants assistant = new Assistants();
private Map<String, Key> keys = Map.of();
private Map<String, Role> roles = Map.of();


public Deployment selectDeployment(String deploymentId) {
Application application = applications.get(deploymentId);
if (application != null) {
return application;
}

Model model = models.get(deploymentId);
if (model != null) {
return model;
}

Assistants assistants = assistant;
Assistant assistant = assistants.getAssistants().get(deploymentId);
if (assistant != null) {
return assistant;
}

if (assistants.getEndpoint() != null && ASSISTANT.equals(deploymentId)) {
Assistant baseAssistant = new Assistant();
baseAssistant.setName(ASSISTANT);
baseAssistant.setEndpoint(assistants.getEndpoint());
baseAssistant.setFeatures(assistants.getFeatures());
return baseAssistant;
}

return null;
}
}
1 change: 0 additions & 1 deletion src/main/java/com/epam/aidial/core/config/Deployment.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public abstract class Deployment {
* API key is forwarded by default.
*/
private boolean forwardApiKey = true;
private String rateEndpoint;
private Features features;
private List<String> inputAttachmentTypes;
private Integer maxInputAttachments;
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/epam/aidial/core/config/Features.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

@Data
public class Features {
private String rateEndpoint;
private String tokenizeEndpoint;
private String truncatePromptEndpoint;
}
29 changes: 27 additions & 2 deletions src/main/java/com/epam/aidial/core/config/FileConfigStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,17 @@ private void load(boolean fail) {
addon.setName(name);
}

for (Map.Entry<String, Assistant> entry : config.getAssistant().getAssistants().entrySet()) {
Assistants assistants = config.getAssistant();
for (Map.Entry<String, Assistant> entry : assistants.getAssistants().entrySet()) {
String name = entry.getKey();
Assistant assistant = entry.getValue();
assistant.setName(name);

if (assistant.getEndpoint() == null) {
assistant.setEndpoint(config.getAssistant().getEndpoint());
assistant.setEndpoint(assistants.getEndpoint());
}

setMissingFeatures(assistant, assistants.getFeatures());
}

for (Map.Entry<String, Application> entry : config.getApplications().entrySet()) {
Expand Down Expand Up @@ -118,4 +121,26 @@ private static InputStream openStream(String path) {
return ConfigStore.class.getClassLoader().getResourceAsStream(path);
}
}

private static void setMissingFeatures(Deployment model, Features features) {
if (features == null) {
return;
}

Features modelFeatures = model.getFeatures();
if (modelFeatures == null) {
model.setFeatures(features);
return;
}

if (modelFeatures.getRateEndpoint() == null) {
modelFeatures.setRateEndpoint(features.getRateEndpoint());
}
if (modelFeatures.getTokenizeEndpoint() == null) {
modelFeatures.setTokenizeEndpoint(features.getTokenizeEndpoint());
}
if (modelFeatures.getTruncatePromptEndpoint() == null) {
modelFeatures.setTruncatePromptEndpoint(features.getTruncatePromptEndpoint());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.epam.aidial.core.Proxy;
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Deployment;
import io.vertx.core.http.HttpMethod;
import lombok.experimental.UtilityClass;

Expand Down Expand Up @@ -158,38 +158,46 @@ private static Controller selectPost(Proxy proxy, ProxyContext context, String p
match = match(PATTERN_RATE_RESPONSE, path);
if (match != null) {
String deploymentId = match.group(1);
RateResponseController controller = new RateResponseController(proxy, context);
return () -> controller.handle(deploymentId);

Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(d -> d.getFeatures())
.map(t -> t.getRateEndpoint())
.orElse(null);
};

DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, false);
}

match = match(PATTERN_TOKENIZE, path);
if (match != null) {
String deploymentId = match.group(1);

Function<Model, String> endpointGetter = (model) -> {
Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(d -> d.getFeatures())
.map(t -> t.getTokenizeEndpoint())
.orElse(null);
};

ModelEndpointController controller = new ModelEndpointController(proxy, context);
return () -> controller.handle(deploymentId, endpointGetter);
DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, true);
}

match = match(PATTERN_TRUNCATE_PROMPT, path);
if (match != null) {
String deploymentId = match.group(1);

Function<Model, String> endpointGetter = (model) -> {
Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(d -> d.getFeatures())
.map(t -> t.getTruncatePromptEndpoint())
.orElse(null);
};

ModelEndpointController controller = new ModelEndpointController(proxy, context);
return () -> controller.handle(deploymentId, endpointGetter);
DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, true);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.epam.aidial.core.Proxy;
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.util.BufferingReadStream;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
Expand All @@ -22,30 +22,40 @@

@Slf4j
@RequiredArgsConstructor
public class ModelEndpointController {
public class DeploymentFeatureController {

private final Proxy proxy;
private final ProxyContext context;

public void handle(String deploymentId, Function<Model, String> endpointGetter) {
Model deployment = context.getConfig().getModels().get(deploymentId);
public void handle(String deploymentId, Function<Deployment, String> endpointGetter, boolean requireEndpoint) {
Deployment deployment = context.getConfig().selectDeployment(deploymentId);

String endpoint = endpointGetter.apply(deployment);

if (endpoint == null || !DeploymentController.hasAccessByUserRoles(context, deployment)) {
if (deployment == null || !DeploymentController.hasAccessByUserRoles(context, deployment)) {
context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
return;
}

String endpoint = endpointGetter.apply(deployment);
context.setDeployment(deployment);
context.getRequest().body()
.onSuccess(buffer -> this.handleRequestBody(endpoint, buffer))
.onSuccess(requestBody -> this.handleRequestBody(endpoint, requireEndpoint, requestBody))
.onFailure(this::handleRequestBodyError);
}

@SneakyThrows
private void handleRequestBody(String endpoint, Buffer requestBody) {
private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer requestBody) {
context.setRequestBody(requestBody);

if (endpoint == null) {
if (requireEndpoint) {
context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
} else {
context.respond(HttpStatus.OK);
proxy.getLogStore().save(context);
}
return;
}

RequestOptions options = new RequestOptions()
.setAbsoluteURI(new URL(endpoint))
.setMethod(context.getRequest().method());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Addon;
import com.epam.aidial.core.config.Assistant;
import com.epam.aidial.core.config.Assistants;
import com.epam.aidial.core.config.Config;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Model;
Expand Down Expand Up @@ -52,7 +51,6 @@
@RequiredArgsConstructor
public class DeploymentPostController {

public static final String ASSISTANT = "assistant";
private static final Set<Integer> RETRIABLE_HTTP_CODES = Set.of(HttpStatus.TOO_MANY_REQUESTS.getCode(),
HttpStatus.BAD_GATEWAY.getCode(), HttpStatus.GATEWAY_TIMEOUT.getCode(),
HttpStatus.SERVICE_UNAVAILABLE.getCode());
Expand All @@ -66,12 +64,17 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
return context.respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "Only application/json is supported");
}

Deployment deployment = select(deploymentId, deploymentApi);
context.setDeployment(deployment);
Deployment deployment = context.getConfig().selectDeployment(deploymentId);
if (!isValidDeploymentApi(deployment, deploymentApi)) {
deployment = null;
}

if (deployment == null || (!isAssistant(deployment) && !DeploymentController.hasAccess(context, deployment))) {
if (deployment == null || (!isBaseAssistant(deployment) && !DeploymentController.hasAccess(context, deployment))) {
return context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
}

context.setDeployment(deployment);

RateLimitResult rateLimitResult;
if (deployment instanceof Model && (rateLimitResult = proxy.getRateLimiter().limit(context)).status() != HttpStatus.OK) {
// Returning an error similar to the Azure format.
Expand Down Expand Up @@ -298,8 +301,7 @@ private void handleResponseError(Throwable error) {
context.getResponse().reset(); // drop connection, so that partial client response won't seem complete
}

private Deployment select(String deploymentId, String deploymentApi) {
Config config = context.getConfig();
private static boolean isValidDeploymentApi(Deployment deployment, String deploymentApi) {
ModelType type = switch (deploymentApi) {
case "completions" -> ModelType.COMPLETION;
case "chat/completions" -> ModelType.CHAT;
Expand All @@ -308,33 +310,16 @@ private Deployment select(String deploymentId, String deploymentApi) {
};

if (type == null) {
return null;
}

Model model = config.getModels().get(deploymentId);
if (model != null) {
return (type == model.getType()) ? model : null;
}

if (type != ModelType.CHAT) {
return null;
}

Assistants assistants = config.getAssistant();

if (assistants.getEndpoint() != null && ASSISTANT.equals(deploymentId)) {
Assistant assistant = new Assistant();
assistant.setName(ASSISTANT);
assistant.setEndpoint(assistants.getEndpoint());
return assistant;
return false;
}

Assistant assistant = assistants.getAssistants().get(deploymentId);
if (assistant != null) {
return assistant;
// Models support all APIs
if (deployment instanceof Model model) {
return type == model.getType();
}

return config.getApplications().get(deploymentId);
// Assistants and applications only support chat API
return type == ModelType.CHAT;
}

private static String buildUri(ProxyContext context) {
Expand Down Expand Up @@ -426,7 +411,7 @@ private static void deletePrompt(ArrayNode messages) {
}
}

private static boolean isAssistant(Deployment deployment) {
return deployment.getName().equals(ASSISTANT);
private static boolean isBaseAssistant(Deployment deployment) {
return deployment.getName().equals(Config.ASSISTANT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ private static FeaturesData createFeatures(Features features) {
return null;
}
FeaturesData data = new FeaturesData();
data.setRate(features.getRateEndpoint() != null);
data.setTokenize(features.getTokenizeEndpoint() != null);
data.setTruncatePrompt(features.getTruncatePromptEndpoint() != null);
return data;
Expand Down
Loading

0 comments on commit 2a84361

Please sign in to comment.