Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: code interpreter #633

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ dependencies {
implementation 'org.hibernate.validator:hibernate-validator:8.0.0.Final'
implementation 'org.glassfish:jakarta.el:4.0.2'
implementation 'jakarta.validation:jakarta.validation-api:3.0.2' // Ensure you have Jakarta Validation API dependency
implementation 'org.apache.httpcomponents.client5:httpclient5:5.4'
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved

testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.3'
testImplementation 'commons-io:commons-io:2.11.0'
testImplementation 'io.vertx:vertx-web-client:4.5.10'
testImplementation 'io.vertx:vertx-junit5:4.5.10'
testImplementation 'org.mockito:mockito-core:5.7.0'
testImplementation 'org.mockito:mockito-junit-jupiter:5.7.0'
testImplementation 'org.apache.httpcomponents.client5:httpclient5:5.4'
testImplementation('com.github.codemonstur:embedded-redis:1.4.3') {
exclude group: 'org.slf4j', module: 'slf4j-simple'
}
Expand Down
12 changes: 8 additions & 4 deletions server/src/main/java/com/epam/aidial/core/server/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.epam.aidial.core.server.security.AccessTokenValidator;
import com.epam.aidial.core.server.security.ApiKeyStore;
import com.epam.aidial.core.server.security.EncryptionService;
import com.epam.aidial.core.server.service.ApplicationOperatorService;
import com.epam.aidial.core.server.service.ApplicationService;
import com.epam.aidial.core.server.service.HeartbeatService;
import com.epam.aidial.core.server.service.InvitationService;
Expand All @@ -18,6 +19,7 @@
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.VertxTimerService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.tracing.DialTracingFactory;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
Expand Down Expand Up @@ -122,8 +124,9 @@ void start() throws Exception {
InvitationService invitationService = new InvitationService(resourceService, encryptionService, settings("invitations"));
ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
ApplicationService applicationService = new ApplicationService(vertx, client, redis,
encryptionService, resourceService, lockService, generator, settings("applications"));
ApplicationOperatorService operatorService = new ApplicationOperatorService(client, settings("applications"));
ApplicationService applicationService = new ApplicationService(vertx, redis, encryptionService,
resourceService, lockService, operatorService, generator, settings("applications"));
ShareService shareService = new ShareService(resourceService, invitationService, encryptionService, applicationService, configStore);
RuleService ruleService = new RuleService(resourceService);
AccessService accessService = new AccessService(encryptionService, shareService, ruleService, settings("access"));
Expand All @@ -133,7 +136,8 @@ void start() throws Exception {
PublicationService publicationService = new PublicationService(encryptionService, resourceService, accessService,
ruleService, notificationService, applicationService, resourceOperationService, generator, clock);
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

CodeInterpreterService codeInterpreterService = new CodeInterpreterService(vertx, redis, resourceService,
accessService, encryptionService, operatorService, generator, settings("codeInterpreter"));

TokenStatsTracker tokenStatsTracker = new TokenStatsTracker(vertx, resourceService);

Expand All @@ -143,7 +147,7 @@ void start() throws Exception {
rateLimiter, upstreamRouteProvider, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
shareService, publicationService, accessService, lockService, resourceOperationService, ruleService,
notificationService, applicationService, heartbeatService, version());
notificationService, applicationService, codeInterpreterService, heartbeatService, version());

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/com/epam/aidial/core/server/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.epam.aidial.core.server.service.ResourceOperationService;
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
import com.epam.aidial.core.server.util.ProxyUtil;
Expand Down Expand Up @@ -88,6 +89,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final RuleService ruleService;
private final NotificationService notificationService;
private final ApplicationService applicationService;
private final CodeInterpreterService codeInterpreterService;
private final HeartbeatService heartbeatService;
private final String version;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,20 @@ public Future<?> respond(HttpStatus status, String contentType, Object object) {
}

public Future<?> respond(HttpStatus status, String body) {
return respond(status.getCode(), body);
}

public Future<?> respond(int status, String body) {
if (body == null) {
body = "";
}

if (status != HttpStatus.OK) {
if (status != HttpStatus.OK.getCode()) {
log.warn("Responding with error. Project: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status,
body.length() > LOG_MAX_ERROR_LENGTH ? body.substring(0, LOG_MAX_ERROR_LENGTH) : body);
}

response.setStatusCode(status.getCode()).end(body);
response.setStatusCode(status).end(body);
return Future.succeededFuture();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package com.epam.aidial.core.server.controller;

import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSessionId;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterError;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.InputStreamAdapter;
import com.epam.aidial.core.server.vertx.stream.InputStreamReader;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerFileUpload;
import io.vertx.core.http.HttpServerResponse;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.io.InputStream;

@Slf4j
class CodeInterpreterController {

private final ProxyContext context;
private final Vertx vertx;
private final CodeInterpreterService service;

public CodeInterpreterController(ProxyContext context) {
this.context = context;
this.vertx = context.getProxy().getVertx();
this.service = context.getProxy().getCodeInterpreterService();
}

Future<?> openSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.openSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> closeSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.closeSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> executeCode() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterExecute data = convertJson(body, CodeInterpreterExecute.class);
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
return vertx.executeBlocking(() -> service.executeCode(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> uploadFile() {
context.getRequest()
.setExpectMultipart(true)
.uploadHandler(upload -> {
// do not move inside execute blocking, otherwise you can miss the beginning of file
InputStream stream = new InputStreamAdapter(upload);
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
vertx.executeBlocking(() -> uploadFile(upload, stream), false)
.onSuccess(this::respondJson)
.onFailure(this::respondError);
});

return Future.succeededFuture();
}

@SneakyThrows
private CodeInterpreterFile uploadFile(HttpServerFileUpload upload, InputStream stream) {
String sessionId = context.getRequest().getParam("session_id");
String fileName = upload.filename();

if (sessionId == null) {
throw new IllegalArgumentException("Missing session_id query param");
}

if (fileName == null) {
throw new IllegalArgumentException("Missing filename in multipart upload");
}

return service.uploadFile(context, sessionId, fileName, stream);
}

Future<?> downloadFile() {
context.getRequest().body()
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
.compose(buffer -> vertx.executeBlocking(() -> downloadFile(buffer), false))
.onFailure(this::respondError);

return Future.succeededFuture();
}

private Void downloadFile(Buffer body) {
CodeInterpreterFile data = convertJson(body, CodeInterpreterFile.class);
HttpServerResponse response = context.getResponse();

return service.downloadFile(context, data.getSessionId(), data.getPath(), (stream, size) -> {
response.putHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(size));
return new InputStreamReader(vertx, stream)
.pipe()
.endOnFailure(false)
.to(response);
});
}

Future<?> listFiles() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.listFiles(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferInputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterInputFile data = convertJson(body, CodeInterpreterInputFile.class);
return vertx.executeBlocking(() -> service.transferInputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferOutputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterOutputFile data = convertJson(body, CodeInterpreterOutputFile.class);
return vertx.executeBlocking(() -> service.transferOutputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

private void respondJson(Object data) {
if (data instanceof CodeInterpreterSession session) {
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
session.setDeploymentId(null);
session.setDeploymentUrl(null);
session.setUsedAt(null);
}

context.respond(HttpStatus.OK, data);
}

private void respondError(Throwable error) {
HttpServerResponse response = context.getResponse();
if (response.headWritten()) {
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
response.reset();
} else if (error instanceof IllegalArgumentException) {
context.respond(HttpStatus.BAD_REQUEST, error.getMessage());
} else if (error instanceof PermissionDeniedException) {
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
context.respond(HttpStatus.NOT_FOUND, error.getMessage());
} else if (error instanceof HttpException e) {
context.respond(e.getStatus(), e.getMessage());
} else if (error instanceof CodeInterpreterError e) {
context.respond(e.getStatus(), e.getMessage());
} else {
log.error("Failed to handle code interpreter request", error);
context.respond(error, "Internal error");
}
}

private static <T> T convertJson(Buffer body, Class<T> clazz) {
try {
T result = ProxyUtil.convertToObject(body, clazz);

if (result == null) {
throw new IllegalArgumentException("No JSON body");
}

return result;
} catch (Exception e) {
throw new IllegalArgumentException("Not valid JSON body");
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public class ControllerSelector {
private static final Pattern USER_INFO = Pattern.compile("^/v1/user/info$");

private static final Pattern APP_SCHEMAS = Pattern.compile("^/v1/application_type_schemas/(schemas|schema|meta_schema)?");
private static final Pattern CODE_INTERPRETER = Pattern.compile("^/v1/ops/code_interpreter/"
+ "(open_session|close_session|execute_code|"
+ "upload_file|download_file|list_files|"
+ "transfer_input_file|transfer_output_file)$");

static {
// GET routes
Expand Down Expand Up @@ -283,6 +287,22 @@ public class ControllerSelector {
default -> null;
};
});
post(CODE_INTERPRETER, (proxy, context, pathMatcher) -> {
String operation = pathMatcher.group(1);
CodeInterpreterController controller = new CodeInterpreterController(context);

return switch (operation) {
case "open_session" -> controller::openSession;
case "close_session" -> controller::closeSession;
case "execute_code" -> controller::executeCode;
case "upload_file" -> controller::uploadFile;
case "download_file" -> controller::downloadFile;
case "list_files" -> controller::listFiles;
case "transfer_input_file" -> controller::transferInputFile;
case "transfer_output_file" -> controller::transferOutputFile;
default -> null;
};
});
// DELETE routes
delete(PATTERN_FILES, (proxy, context, pathMatcher) -> {
ResourceController controller = new ResourceController(proxy, context, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.epam.aidial.core.server.data;

import lombok.Data;

import javax.annotation.Nullable;


@Data
public class AuthBucket {

/**
* The encrypted bucket location for the original JWT or API_KEY.
*/
String userBucket;
/**
* The bucket location for the original JWT or API_KEY.
*/
String userBucketLocation;

/**
* The encrypted bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucket;

/**
* The bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucketLocation;
}
Loading
Loading