From 29f307992ce7d8b1ab8b66248bcfe7b20b8c560c Mon Sep 17 00:00:00 2001 From: Artsiom Korzun Date: Thu, 9 Jan 2025 10:48:42 +0100 Subject: [PATCH 1/5] add code interpreter feature --- server/build.gradle | 2 +- .../com/epam/aidial/core/server/AiDial.java | 12 +- .../com/epam/aidial/core/server/Proxy.java | 2 + .../epam/aidial/core/server/ProxyContext.java | 8 +- .../controller/CodeInterpreterController.java | 215 ++++++++++ .../server/controller/ControllerSelector.java | 20 + .../aidial/core/server/data/AuthBucket.java | 31 ++ .../core/server/data/ResourceTypes.java | 5 +- .../CodeInterpreterExecute.java | 15 + .../CodeInterpreterExecution.java | 14 + .../codeinterpreter/CodeInterpreterFile.java | 12 + .../codeinterpreter/CodeInterpreterFiles.java | 12 + .../CodeInterpreterInputFile.java | 12 + .../CodeInterpreterOutputFile.java | 12 + .../CodeInterpreterSession.java | 13 + .../CodeInterpreterSessionId.java | 10 + .../service/ApplicationOperatorService.java | 33 +- .../server/service/ApplicationService.java | 4 +- .../CodeInterpreterClient.java | 117 ++++++ .../codeinterpreter/CodeInterpreterError.java | 13 + .../CodeInterpreterService.java | 380 ++++++++++++++++++ .../core/server/util/BucketBuilder.java | 26 ++ .../vertx/stream/InputStreamAdapter.java | 122 ++++++ .../core/server/CodeInterpreterApiTest.java | 132 ++++++ .../aidial/core/server/ResourceBaseTest.java | 3 + .../core/storage/service/ResourceService.java | 20 +- 26 files changed, 1225 insertions(+), 20 deletions(-) create mode 100644 server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/AuthBucket.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java create mode 100644 
server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFile.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFiles.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterInputFile.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterOutputFile.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSession.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSessionId.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java create mode 100644 server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java create mode 100644 server/src/test/java/com/epam/aidial/core/server/CodeInterpreterApiTest.java diff --git a/server/build.gradle b/server/build.gradle index a81fb2931..58978540a 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -36,6 +36,7 @@ dependencies { implementation 'org.hibernate.validator:hibernate-validator:8.0.0.Final' implementation 'org.glassfish:jakarta.el:4.0.2' implementation 'jakarta.validation:jakarta.validation-api:3.0.2' // Ensure you have Jakarta Validation API dependency + implementation 'org.apache.httpcomponents.client5:httpclient5:5.4' testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.3' testImplementation 'commons-io:commons-io:2.11.0' @@ -43,7 +44,6 @@ dependencies { testImplementation 'io.vertx:vertx-junit5:4.5.10' testImplementation 'org.mockito:mockito-core:5.7.0' testImplementation 
'org.mockito:mockito-junit-jupiter:5.7.0' - testImplementation 'org.apache.httpcomponents.client5:httpclient5:5.4' testImplementation('com.github.codemonstur:embedded-redis:1.4.3') { exclude group: 'org.slf4j', module: 'slf4j-simple' } diff --git a/server/src/main/java/com/epam/aidial/core/server/AiDial.java b/server/src/main/java/com/epam/aidial/core/server/AiDial.java index e87b1bf84..5433fc30d 100644 --- a/server/src/main/java/com/epam/aidial/core/server/AiDial.java +++ b/server/src/main/java/com/epam/aidial/core/server/AiDial.java @@ -9,6 +9,7 @@ import com.epam.aidial.core.server.security.AccessTokenValidator; import com.epam.aidial.core.server.security.ApiKeyStore; import com.epam.aidial.core.server.security.EncryptionService; +import com.epam.aidial.core.server.service.ApplicationOperatorService; import com.epam.aidial.core.server.service.ApplicationService; import com.epam.aidial.core.server.service.HeartbeatService; import com.epam.aidial.core.server.service.InvitationService; @@ -18,6 +19,7 @@ import com.epam.aidial.core.server.service.RuleService; import com.epam.aidial.core.server.service.ShareService; import com.epam.aidial.core.server.service.VertxTimerService; +import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService; import com.epam.aidial.core.server.token.TokenStatsTracker; import com.epam.aidial.core.server.tracing.DialTracingFactory; import com.epam.aidial.core.server.upstream.UpstreamRouteProvider; @@ -122,8 +124,9 @@ void start() throws Exception { InvitationService invitationService = new InvitationService(resourceService, encryptionService, settings("invitations")); ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, vertx); ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore); - ApplicationService applicationService = new ApplicationService(vertx, client, redis, - encryptionService, resourceService, lockService, generator, settings("applications")); + 
ApplicationOperatorService operatorService = new ApplicationOperatorService(client, settings("applications")); + ApplicationService applicationService = new ApplicationService(vertx, redis, encryptionService, + resourceService, lockService, operatorService, generator, settings("applications")); ShareService shareService = new ShareService(resourceService, invitationService, encryptionService, applicationService, configStore); RuleService ruleService = new RuleService(resourceService); AccessService accessService = new AccessService(encryptionService, shareService, ruleService, settings("access")); @@ -133,7 +136,8 @@ void start() throws Exception { PublicationService publicationService = new PublicationService(encryptionService, resourceService, accessService, ruleService, notificationService, applicationService, resourceOperationService, generator, clock); RateLimiter rateLimiter = new RateLimiter(vertx, resourceService); - + CodeInterpreterService codeInterpreterService = new CodeInterpreterService(vertx, redis, resourceService, + accessService, encryptionService, operatorService, generator, settings("codeInterpreter")); TokenStatsTracker tokenStatsTracker = new TokenStatsTracker(vertx, resourceService); @@ -143,7 +147,7 @@ void start() throws Exception { rateLimiter, upstreamRouteProvider, accessTokenValidator, storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService, shareService, publicationService, accessService, lockService, resourceOperationService, ruleService, - notificationService, applicationService, heartbeatService, version()); + notificationService, applicationService, codeInterpreterService, heartbeatService, version()); server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy); open(server, HttpServer::listen); diff --git a/server/src/main/java/com/epam/aidial/core/server/Proxy.java b/server/src/main/java/com/epam/aidial/core/server/Proxy.java index 7449864fd..94d8a3043 
100644 --- a/server/src/main/java/com/epam/aidial/core/server/Proxy.java +++ b/server/src/main/java/com/epam/aidial/core/server/Proxy.java @@ -21,6 +21,7 @@ import com.epam.aidial.core.server.service.ResourceOperationService; import com.epam.aidial.core.server.service.RuleService; import com.epam.aidial.core.server.service.ShareService; +import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService; import com.epam.aidial.core.server.token.TokenStatsTracker; import com.epam.aidial.core.server.upstream.UpstreamRouteProvider; import com.epam.aidial.core.server.util.ProxyUtil; @@ -88,6 +89,7 @@ public class Proxy implements Handler { private final RuleService ruleService; private final NotificationService notificationService; private final ApplicationService applicationService; + private final CodeInterpreterService codeInterpreterService; private final HeartbeatService heartbeatService; private final String version; diff --git a/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java b/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java index 47525ccf6..e250db99c 100644 --- a/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java +++ b/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java @@ -146,16 +146,20 @@ public Future respond(HttpStatus status, String contentType, Object object) { } public Future respond(HttpStatus status, String body) { + return respond(status.getCode(), body); + } + + public Future respond(int status, String body) { if (body == null) { body = ""; } - if (status != HttpStatus.OK) { + if (status != HttpStatus.OK.getCode()) { log.warn("Responding with error. Project: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status, body.length() > LOG_MAX_ERROR_LENGTH ? 
body.substring(0, LOG_MAX_ERROR_LENGTH) : body); } - response.setStatusCode(status.getCode()).end(body); + response.setStatusCode(status).end(body); return Future.succeededFuture(); } diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java new file mode 100644 index 000000000..a2b5effc9 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java @@ -0,0 +1,215 @@ +package com.epam.aidial.core.server.controller; + +import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSessionId; +import com.epam.aidial.core.server.service.PermissionDeniedException; +import com.epam.aidial.core.server.service.ResourceNotFoundException; +import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterError; +import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService; +import com.epam.aidial.core.server.util.ProxyUtil; +import com.epam.aidial.core.server.vertx.stream.InputStreamAdapter; +import com.epam.aidial.core.server.vertx.stream.InputStreamReader; +import com.epam.aidial.core.storage.http.HttpException; +import com.epam.aidial.core.storage.http.HttpStatus; +import io.vertx.core.Future; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpHeaders; +import io.vertx.core.http.HttpServerFileUpload; +import io.vertx.core.http.HttpServerResponse; +import 
lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +import java.io.InputStream; + +@Slf4j +class CodeInterpreterController { + + private final ProxyContext context; + private final Vertx vertx; + private final CodeInterpreterService service; + + public CodeInterpreterController(ProxyContext context) { + this.context = context; + this.vertx = context.getProxy().getVertx(); + this.service = context.getProxy().getCodeInterpreterService(); + } + + Future openSession() { + context.getRequest() + .body() + .compose(body -> { + CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class); + return vertx.executeBlocking(() -> service.openSession(context, data.getSessionId()), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + Future closeSession() { + context.getRequest() + .body() + .compose(body -> { + CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class); + return vertx.executeBlocking(() -> service.closeSession(context, data.getSessionId()), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + Future executeCode() { + context.getRequest() + .body() + .compose(body -> { + CodeInterpreterExecute data = convertJson(body, CodeInterpreterExecute.class); + return vertx.executeBlocking(() -> service.executeCode(context, data), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + Future uploadFile() { + context.getRequest() + .setExpectMultipart(true) + .uploadHandler(upload -> { + // do not move inside execute blocking, otherwise you can miss the beginning of file + InputStream stream = new InputStreamAdapter(upload); + vertx.executeBlocking(() -> uploadFile(upload, stream), false) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + }); + + return Future.succeededFuture(); + } + + @SneakyThrows 
+ private CodeInterpreterFile uploadFile(HttpServerFileUpload upload, InputStream stream) { + String sessionId = context.getRequest().getParam("session_id"); + String fileName = upload.filename(); + + if (sessionId == null) { + throw new IllegalArgumentException("Missing session_id query param"); + } + + if (fileName == null) { + throw new IllegalArgumentException("Missing filename in multipart upload"); + } + + return service.uploadFile(context, sessionId, fileName, stream); + } + + Future downloadFile() { + context.getRequest().body() + .compose(buffer -> vertx.executeBlocking(() -> downloadFile(buffer), false)) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + private Void downloadFile(Buffer body) { + CodeInterpreterFile data = convertJson(body, CodeInterpreterFile.class); + HttpServerResponse response = context.getResponse(); + + return service.downloadFile(context, data.getSessionId(), data.getPath(), (stream, size) -> { + response.putHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(size)); + return new InputStreamReader(vertx, stream) + .pipe() + .endOnFailure(false) + .to(response); + }); + } + + Future listFiles() { + context.getRequest() + .body() + .compose(body -> { + CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class); + return vertx.executeBlocking(() -> service.listFiles(context, data.getSessionId()), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + Future transferInputFile() { + context.getRequest() + .body() + .compose(body -> { + CodeInterpreterInputFile data = convertJson(body, CodeInterpreterInputFile.class); + return vertx.executeBlocking(() -> service.transferInputFile(context, data), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + Future transferOutputFile() { + context.getRequest() + .body() + .compose(body -> { + 
CodeInterpreterOutputFile data = convertJson(body, CodeInterpreterOutputFile.class); + return vertx.executeBlocking(() -> service.transferOutputFile(context, data), false); + }) + .onSuccess(this::respondJson) + .onFailure(this::respondError); + + return Future.succeededFuture(); + } + + private void respondJson(Object data) { + if (data instanceof CodeInterpreterSession session) { + session.setDeploymentId(null); + session.setDeploymentUrl(null); + session.setUsedAt(null); + } + + context.respond(HttpStatus.OK, data); + } + + private void respondError(Throwable error) { + HttpServerResponse response = context.getResponse(); + if (response.headWritten()) { + response.reset(); + } else if (error instanceof IllegalArgumentException) { + context.respond(HttpStatus.BAD_REQUEST, error.getMessage()); + } else if (error instanceof PermissionDeniedException) { + context.respond(HttpStatus.FORBIDDEN, error.getMessage()); + } else if (error instanceof ResourceNotFoundException) { + context.respond(HttpStatus.NOT_FOUND, error.getMessage()); + } else if (error instanceof HttpException e) { + context.respond(e.getStatus(), e.getMessage()); + } else if (error instanceof CodeInterpreterError e) { + context.respond(e.getStatus(), e.getMessage()); + } else { + log.error("Failed to handle code interpreter request", error); + context.respond(error, "Internal error"); + } + } + + private static T convertJson(Buffer body, Class clazz) { + try { + T result = ProxyUtil.convertToObject(body, clazz); + + if (result == null) { + throw new IllegalArgumentException("No JSON body"); + } + + return result; + } catch (Exception e) { + throw new IllegalArgumentException("Not valid JSON body"); + } + } +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/ControllerSelector.java b/server/src/main/java/com/epam/aidial/core/server/controller/ControllerSelector.java index b037372e8..3b3142f5e 100644 --- 
a/server/src/main/java/com/epam/aidial/core/server/controller/ControllerSelector.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/ControllerSelector.java @@ -75,6 +75,10 @@ public class ControllerSelector { private static final Pattern USER_INFO = Pattern.compile("^/v1/user/info$"); private static final Pattern APP_SCHEMAS = Pattern.compile("^/v1/application_type_schemas/(schemas|schema|meta_schema)?"); + private static final Pattern CODE_INTERPRETER = Pattern.compile("^/v1/ops/code_interpreter/" + + "(open_session|close_session|execute_code|" + + "upload_file|download_file|list_files|" + + "transfer_input_file|transfer_output_file)$"); static { // GET routes @@ -283,6 +287,22 @@ public class ControllerSelector { default -> null; }; }); + post(CODE_INTERPRETER, (proxy, context, pathMatcher) -> { + String operation = pathMatcher.group(1); + CodeInterpreterController controller = new CodeInterpreterController(context); + + return switch (operation) { + case "open_session" -> controller::openSession; + case "close_session" -> controller::closeSession; + case "execute_code" -> controller::executeCode; + case "upload_file" -> controller::uploadFile; + case "download_file" -> controller::downloadFile; + case "list_files" -> controller::listFiles; + case "transfer_input_file" -> controller::transferInputFile; + case "transfer_output_file" -> controller::transferOutputFile; + default -> null; + }; + }); // DELETE routes delete(PATTERN_FILES, (proxy, context, pathMatcher) -> { ResourceController controller = new ResourceController(proxy, context, false); diff --git a/server/src/main/java/com/epam/aidial/core/server/data/AuthBucket.java b/server/src/main/java/com/epam/aidial/core/server/data/AuthBucket.java new file mode 100644 index 000000000..a434f9e23 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/AuthBucket.java @@ -0,0 +1,31 @@ +package com.epam.aidial.core.server.data; + +import lombok.Data; + +import 
javax.annotation.Nullable; + + +@Data +public class AuthBucket { + + /** + * The encrypted bucket location for the original JWT or API_KEY. + */ + String userBucket; + /** + * The bucket location for the original JWT or API_KEY. + */ + String userBucketLocation; + + /** + * The encrypted bucket location for the application from PER_REQUEST_KEY if present. + */ + @Nullable + String appBucket; + + /** + * The bucket location for the application from PER_REQUEST_KEY if present. + */ + @Nullable + String appBucketLocation; +} diff --git a/server/src/main/java/com/epam/aidial/core/server/data/ResourceTypes.java b/server/src/main/java/com/epam/aidial/core/server/data/ResourceTypes.java index d0366fee9..3333009eb 100644 --- a/server/src/main/java/com/epam/aidial/core/server/data/ResourceTypes.java +++ b/server/src/main/java/com/epam/aidial/core/server/data/ResourceTypes.java @@ -7,7 +7,8 @@ public enum ResourceTypes implements ResourceType { PROMPT("prompts", true), LIMIT("limits", true), SHARED_WITH_ME("shared_with_me", true), SHARED_BY_ME("shared_by_me", true), INVITATION("invitations", true), PUBLICATION("publications", true), RULES("rules", true), API_KEY_DATA("api_key_data", true), NOTIFICATION("notifications", true), - APPLICATION("applications", true), DEPLOYMENT_COST_STATS("deployment_cost_stats", true); + APPLICATION("applications", true), DEPLOYMENT_COST_STATS("deployment_cost_stats", true), + CODE_INTERPRETER_SESSION("code_interpreter_session", true); private final String group; private final boolean requireCompression; @@ -17,8 +18,6 @@ public enum ResourceTypes implements ResourceType { this.requireCompression = requireCompression; } - - public static ResourceTypes of(String group) { return switch (group) { case "files" -> FILE; diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java new file mode 100644 
index 000000000..4465e0f08 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java @@ -0,0 +1,15 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +import java.util.List; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterExecute { + private String sessionId; + private String code; + private List inputFiles; + private List outputFiles; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java new file mode 100644 index 000000000..0251c253c --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java @@ -0,0 +1,14 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterExecution { + private String status; + private String stdout; + private String stderr; + private Object result; + private Object display; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFile.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFile.java new file mode 100644 index 000000000..09daee6d6 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFile.java @@ -0,0 +1,12 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterFile { + String sessionId; + String path; + Long size; 
+} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFiles.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFiles.java new file mode 100644 index 000000000..c9becee99 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterFiles.java @@ -0,0 +1,12 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +import java.util.List; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterFiles { + List files; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterInputFile.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterInputFile.java new file mode 100644 index 000000000..a841a1d23 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterInputFile.java @@ -0,0 +1,12 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterInputFile { + String sessionId; + String sourceUrl; + String targetPath; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterOutputFile.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterOutputFile.java new file mode 100644 index 000000000..74123da2e --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterOutputFile.java @@ -0,0 +1,12 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data 
+@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterOutputFile { + String sessionId; + String sourcePath; + String targetUrl; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSession.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSession.java new file mode 100644 index 000000000..edb3eaf5c --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSession.java @@ -0,0 +1,13 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterSession { + String sessionId; + String deploymentId; + String deploymentUrl; + Long usedAt; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSessionId.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSessionId.java new file mode 100644 index 000000000..bcbc1d1a9 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterSessionId.java @@ -0,0 +1,10 @@ +package com.epam.aidial.core.server.data.codeinterpreter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CodeInterpreterSessionId { + String sessionId; +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/service/ApplicationOperatorService.java b/server/src/main/java/com/epam/aidial/core/server/service/ApplicationOperatorService.java index 499a80675..247d0859f 100644 --- a/server/src/main/java/com/epam/aidial/core/server/service/ApplicationOperatorService.java +++ 
b/server/src/main/java/com/epam/aidial/core/server/service/ApplicationOperatorService.java @@ -7,6 +7,7 @@ import com.epam.aidial.core.storage.http.HttpException; import com.epam.aidial.core.storage.http.HttpStatus; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; import io.vertx.core.http.HttpClient; import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpHeaders; @@ -28,7 +29,7 @@ /** * A web client to Application Controller Web Service that manages deployments for applications with functions. */ -class ApplicationOperatorService { +public class ApplicationOperatorService { private final HttpClient client; private final String endpoint; @@ -40,11 +41,11 @@ public ApplicationOperatorService(HttpClient client, JsonObject settings) { this.timeout = settings.getLong("controllerTimeout", 240000L); } - boolean isActive() { + public boolean isActive() { return endpoint != null; } - void verifyActive() { + public void verifyActive() { if (!isActive()) { throw new HttpException(HttpStatus.SERVICE_UNAVAILABLE, "The application controller is not available"); } @@ -87,7 +88,6 @@ String createApplicationDeployment(ProxyContext context, Application.Function fu } request.putHeader(HttpHeaders.CONTENT_TYPE, Proxy.HEADER_CONTENT_TYPE_APPLICATION_JSON); - CreateDeploymentRequest body = new CreateDeploymentRequest(function.getEnv()); return ProxyUtil.convertToString(body); }, @@ -114,6 +114,24 @@ Application.Logs getApplicationLogs(Application.Function function) { body -> ProxyUtil.convertToObject(body, Application.Logs.class)); } + public String createCodeInterpreterDeployment(String id, String image) { + CreateDeploymentResponse deployment = callController(HttpMethod.POST, "/v1/deployment/" + id, + request -> { + request.putHeader(HttpHeaders.CONTENT_TYPE, Proxy.HEADER_CONTENT_TYPE_APPLICATION_JSON); + CreateDeploymentRequest body = new CreateDeploymentRequest(image, 1, 1, 1, Map.of()); + return 
ProxyUtil.convertToString(body); + }, + body -> convertServerSentEvent(body, CreateDeploymentResponse.class)); + + return deployment.url(); + } + + public void deleteCodeInterpreterDeployment(String id) { + callController(HttpMethod.DELETE, "/v1/deployment/" + id, + request -> null, + body -> convertServerSentEvent(body, EmptyResponse.class)); + } + @SneakyThrows private R callController(HttpMethod method, String path, Function requestMapper, @@ -206,7 +224,12 @@ private static T convertServerSentEvent(String body, Class clazz) { private record CreateImageRequest(String runtime, String sources) { } - private record CreateDeploymentRequest(Map env) { + @JsonInclude(JsonInclude.Include.NON_NULL) + private record CreateDeploymentRequest(String image, Integer initialScale, Integer minScale, Integer maxScale, + Map env) { + private CreateDeploymentRequest(Map env) { + this(null, null, null, null, env); + } } @JsonIgnoreProperties(ignoreUnknown = true) diff --git a/server/src/main/java/com/epam/aidial/core/server/service/ApplicationService.java b/server/src/main/java/com/epam/aidial/core/server/service/ApplicationService.java index ec14b47f9..497eb4467 100644 --- a/server/src/main/java/com/epam/aidial/core/server/service/ApplicationService.java +++ b/server/src/main/java/com/epam/aidial/core/server/service/ApplicationService.java @@ -64,11 +64,11 @@ public class ApplicationService { private final boolean includeCustomApps; public ApplicationService(Vertx vertx, - HttpClient httpClient, RedissonClient redis, EncryptionService encryptionService, ResourceService resourceService, LockService lockService, + ApplicationOperatorService operatorService, Supplier idGenerator, JsonObject settings) { String pendingApplicationsKey = BlobStorageUtil.toStoragePath(lockService.getPrefix(), "pending-applications"); @@ -79,7 +79,7 @@ public ApplicationService(Vertx vertx, this.lockService = lockService; this.idGenerator = idGenerator; this.pendingApplications = 
redis.getScoredSortedSet(pendingApplicationsKey, StringCodec.INSTANCE); - this.controller = new ApplicationOperatorService(httpClient, settings); + this.controller = operatorService; this.checkDelay = settings.getLong("checkDelay", 300000L); this.checkSize = settings.getInteger("checkSize", 64); this.includeCustomApps = settings.getBoolean("includeCustomApps", false); diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java new file mode 100644 index 000000000..cf90016ac --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java @@ -0,0 +1,117 @@ +package com.epam.aidial.core.server.service.codeinterpreter; + +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecution; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFiles; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession; +import com.epam.aidial.core.server.util.ProxyUtil; +import io.vertx.core.Future; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.config.RequestConfig; +import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.HttpEntities; + +import java.io.InputStream; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import 
java.util.concurrent.TimeUnit; + +@RequiredArgsConstructor +public class CodeInterpreterClient { + + // Vertx HttpClient does not support multipart upload, Vertx WebClient supports only Buffer as body for multipart upload + private final HttpClient client = HttpClients.createDefault(); + private final long timeout; + + CodeInterpreterExecution executeCode(CodeInterpreterSession session, String code) { + Map body = Map.of("code", code); + return execute(session, "/execute_code", body, CodeInterpreterExecution.class); + } + + CodeInterpreterFiles listFiles(CodeInterpreterSession session) { + Map body = Map.of(); + return execute(session, "/list_files", body, CodeInterpreterFiles.class); + } + + @SneakyThrows + CodeInterpreterFile uploadFile(CodeInterpreterSession session, InputStream source, String target) { + HttpPost post = new HttpPost(session.getDeploymentUrl() + "/upload_file"); + post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setEntity(MultipartEntityBuilder.create() + .addBinaryBody("file", source, ContentType.APPLICATION_OCTET_STREAM, target) + .build()); + + return client.execute(post, response -> { + int status = response.getCode(); + String body = EntityUtils.toString(response.getEntity()); + + if (status != 200) { + throw new CodeInterpreterError(status, body); + } + + return ProxyUtil.convertToObject(body, CodeInterpreterFile.class); + }); + } + + @SneakyThrows + R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunction consumer) { + HttpPost post = new HttpPost(session.getDeploymentUrl() + "/download_file"); + post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setEntity(HttpEntities.create(ProxyUtil.convertToString(Map.of("path", path)), ContentType.APPLICATION_JSON)); + + return client.execute(post, response -> { + int status = response.getCode(); + HttpEntity entity = response.getEntity(); + + if (status != 200) { + 
String body = EntityUtils.toString(entity); + throw new CodeInterpreterError(status, body); + } + + try { + CompletableFuture result = new CompletableFuture<>(); + long size = Long.parseLong(response.getHeader(HttpHeaders.CONTENT_LENGTH).getValue()); + InputStream stream = entity.getContent(); + + consumer.apply(stream, size) + .onSuccess(result::complete) + .onFailure(result::completeExceptionally); + + return result.get(timeout, TimeUnit.MILLISECONDS); + } catch (Throwable e) { + EntityUtils.consumeQuietly(entity); + throw new CodeInterpreterError(500, "Failed to download file: " + path); + } + }); + } + + @SneakyThrows + private R execute(CodeInterpreterSession session, String path, Object requestPayload, Class responseType) { + HttpPost post = new HttpPost(session.getDeploymentUrl() + path); + post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setEntity(HttpEntities.create(ProxyUtil.convertToString(requestPayload), ContentType.APPLICATION_JSON)); + + return client.execute(post, response -> { + int status = response.getCode(); + String body = EntityUtils.toString(response.getEntity()); + + if (status != 200) { + throw new CodeInterpreterError(status, body); + } + + return ProxyUtil.convertToObject(body, responseType); + }); + } + + public interface DownloadFileFunction { + Future apply(InputStream stream, long size) throws Throwable; + } +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java new file mode 100644 index 000000000..1fe97b0e0 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java @@ -0,0 +1,13 @@ +package com.epam.aidial.core.server.service.codeinterpreter; + +import lombok.Getter; + +@Getter +public class CodeInterpreterError extends 
RuntimeException { + private final int status; + + public CodeInterpreterError(int status, String body) { + super(body); + this.status = status; + } +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java new file mode 100644 index 000000000..3c3a6e4a1 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java @@ -0,0 +1,380 @@ +package com.epam.aidial.core.server.service.codeinterpreter; + +import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.AuthBucket; +import com.epam.aidial.core.server.data.ResourceTypes; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecution; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFiles; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession; +import com.epam.aidial.core.server.security.AccessService; +import com.epam.aidial.core.server.security.EncryptionService; +import com.epam.aidial.core.server.service.ApplicationOperatorService; +import com.epam.aidial.core.server.service.PermissionDeniedException; +import com.epam.aidial.core.server.service.ResourceNotFoundException; +import com.epam.aidial.core.server.util.BucketBuilder; +import com.epam.aidial.core.server.util.ResourceDescriptorFactory; +import com.epam.aidial.core.server.vertx.stream.BlobWriteStream; +import com.epam.aidial.core.server.vertx.stream.InputStreamReader; +import 
com.epam.aidial.core.storage.blobstore.BlobStorageUtil; +import com.epam.aidial.core.storage.data.FileMetadata; +import com.epam.aidial.core.storage.http.HttpException; +import com.epam.aidial.core.storage.http.HttpStatus; +import com.epam.aidial.core.storage.resource.ResourceDescriptor; +import com.epam.aidial.core.storage.service.LockService; +import com.epam.aidial.core.storage.service.ResourceService; +import com.epam.aidial.core.storage.util.EtagHeader; +import io.vertx.core.Vertx; +import io.vertx.core.json.JsonObject; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.redisson.api.RScoredSortedSet; +import org.redisson.api.RedissonClient; + +import java.io.InputStream; +import java.nio.file.Path; +import java.util.Objects; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static com.epam.aidial.core.server.util.ProxyUtil.convertToObject; +import static com.epam.aidial.core.server.util.ProxyUtil.convertToString; + +@Slf4j +public class CodeInterpreterService { + + private final Vertx vertx; + private final ResourceService resourceService; + private final AccessService accessService; + private final EncryptionService encryptionService; + private final ApplicationOperatorService operatorService; + private final RScoredSortedSet activeSessions; + private final CodeInterpreterClient client; + private final Supplier idGenerator; + private final String sessionImage; + private final long sessionTtl; + private final int checkSize; + + public CodeInterpreterService(Vertx vertx, RedissonClient redisson, + ResourceService resourceService, AccessService accessService, + EncryptionService encryptionService, ApplicationOperatorService operatorService, + Supplier idGenerator, JsonObject settings) { + String activeSessionsKey = BlobStorageUtil.toStoragePath(resourceService.getPrefix(), "active-code-interpreter-sessions"); + + this.vertx = vertx; + this.resourceService = resourceService; + this.accessService = 
accessService; + this.encryptionService = encryptionService; + this.operatorService = operatorService; + this.idGenerator = idGenerator; + this.activeSessions = redisson.getScoredSortedSet(activeSessionsKey); + this.sessionImage = settings.getString("sessionImage"); + this.sessionTtl = settings.getLong("sessionTtl", 600000L); + this.checkSize = settings.getInteger("checkSize", 256); + this.client = new CodeInterpreterClient(sessionTtl); + + if (isActive()) { + long checkPeriod = settings.getLong("checkPeriod", 10000L); + vertx.setPeriodic(checkPeriod, checkPeriod, ignore -> vertx.executeBlocking(this::checkSessions)); + } + } + + private Void checkSessions() { + log.debug("Checking active sessions"); + try { + long now = System.currentTimeMillis(); + + for (String url : activeSessions.valueRange(Double.NEGATIVE_INFINITY, true, now, true, 0, checkSize)) { + log.debug("Checking active session: {}", url); + ResourceDescriptor resource = ResourceDescriptorFactory.fromAnyUrl(url, encryptionService); + Predicate ifExpired = session -> System.currentTimeMillis() - session.getUsedAt() >= sessionTtl; + cleanupSession(resource, ifExpired); + } + } catch (Throwable e) { + log.warn("Failed to check active sessions", e); + } + + return null; + } + + private void cleanupSession(ResourceDescriptor resource, Predicate predicate) { + try (LockService.Lock lock = resourceService.tryLockResource(resource)) { + if (lock == null) { + return; + } + + String json = resourceService.getResource(resource, EtagHeader.ANY, false); + CodeInterpreterSession session = convertToObject(json, CodeInterpreterSession.class); + + if (session != null && predicate.test(session)) { + operatorService.deleteCodeInterpreterDeployment(session.getDeploymentId()); + resourceService.deleteResource(resource, EtagHeader.ANY, false); + session = null; + } + + if (session == null) { + activeSessions.remove(resource.getUrl()); + } + } catch (Throwable e) { + log.warn("Failed to cleanup active session", e); + } + } + 
+ public CodeInterpreterSession touchSession(ProxyContext context, String sessionId) { + verifyActive(); + verifySessionId(sessionId); + + ResourceDescriptor resource = sessionResource(context, sessionId); + try (LockService.Lock lock = resourceService.lockResource(resource)) { + String json = resourceService.getResource(resource, EtagHeader.ANY, false); + CodeInterpreterSession session = convertToObject(json, CodeInterpreterSession.class); + + if (session == null) { + throw new ResourceNotFoundException("Session is not found: " + sessionId); + } + + if (session.getDeploymentUrl() == null) { + throw new IllegalStateException("Session is not yet initialized: " + sessionId); + } + + session.setUsedAt(System.currentTimeMillis()); + activeSessions.add(session.getUsedAt() + sessionTtl, resource.getUrl()); + resourceService.putResource(resource, convertToString(session), EtagHeader.ANY, false); + return session; + } + } + + public CodeInterpreterSession openSession(ProxyContext context, String sessionId) { + verifyActive(); + + if (sessionId == null) { + sessionId = idGenerator.get(); + } + + ResourceDescriptor resource = sessionResource(context, sessionId); + CodeInterpreterSession session = new CodeInterpreterSession(); + session.setSessionId(sessionId); + session.setDeploymentId(idGenerator.get()); + boolean cleanup = false; + + try (LockService.Lock lock = resourceService.lockResource(resource)) { + String json = resourceService.getResource(resource, EtagHeader.ANY, false); + CodeInterpreterSession existing = convertToObject(json, CodeInterpreterSession.class); + if (existing != null) { + throw new IllegalArgumentException("Session already exists: " + session.getSessionId()); + } + + cleanup = true; + session.setUsedAt(System.currentTimeMillis()); + activeSessions.add(session.getUsedAt() + sessionTtl, resource.getUrl()); + resourceService.putResource(resource, convertToString(session), EtagHeader.ANY, false); + + String deploymentUrl = 
operatorService.createCodeInterpreterDeployment(session.getDeploymentId(), sessionImage); + session.setDeploymentUrl(deploymentUrl); + session.setUsedAt(System.currentTimeMillis()); + + activeSessions.add(session.getUsedAt() + sessionTtl, resource.getUrl()); + resourceService.putResource(resource, convertToString(session), EtagHeader.ANY, false); + } catch (Throwable error) { + if (cleanup) { + Predicate ifMatch = candidate -> Objects.equals(candidate.getDeploymentId(), session.getDeploymentId()); + cleanupSession(resource, ifMatch); + } + + throw error; + } + + return session; + } + + public CodeInterpreterSession closeSession(ProxyContext context, String sessionId) { + verifyActive(); + verifySessionId(sessionId); + + ResourceDescriptor resource = sessionResource(context, sessionId); + try (LockService.Lock lock = resourceService.lockResource(resource)) { + String json = resourceService.getResource(resource, EtagHeader.ANY, false); + CodeInterpreterSession session = convertToObject(json, CodeInterpreterSession.class); + + if (session == null) { + throw new ResourceNotFoundException("Session is not found: " + sessionId); + } + + operatorService.deleteCodeInterpreterDeployment(session.getDeploymentId()); + resourceService.deleteResource(resource, EtagHeader.ANY, false); + activeSessions.remove(resource.getUrl()); + return session; + } + } + + public CodeInterpreterExecution executeCode(ProxyContext context, CodeInterpreterExecute request) { + verifyActive(); + verifyCode(request); + + boolean anonymous = (request.getSessionId() == null); + CodeInterpreterSession session; + + if (anonymous) { + session = openSession(context, null); + } else { + session = touchSession(context, request.getSessionId()); + } + + if (request.getInputFiles() != null) { + for (CodeInterpreterInputFile input : request.getInputFiles()) { + input.setSessionId(session.getSessionId()); + transferInputFile(context, input); + } + } + + CodeInterpreterExecution response = 
client.executeCode(session, request.getCode()); + + if (request.getOutputFiles() != null) { + for (CodeInterpreterOutputFile output : request.getOutputFiles()) { + output.setSessionId(session.getSessionId()); + transferOutputFile(context, output); + } + } + + if (anonymous) { + closeSession(context, session.getSessionId()); + } + + return response; + } + + @SneakyThrows + public CodeInterpreterFile uploadFile(ProxyContext context, String sessionId, String path, InputStream stream) { + try (InputStream resource = stream) { + verifyActive(); + verifySessionId(sessionId); + verifyPath(path); + + CodeInterpreterSession session = touchSession(context, sessionId); + return client.uploadFile(session, stream, path); + } + } + + public R downloadFile(ProxyContext context, String sessionId, String path, CodeInterpreterClient.DownloadFileFunction function) { + verifyActive(); + verifySessionId(sessionId); + verifyPath(path); + + CodeInterpreterSession session = touchSession(context, sessionId); + return client.downloadFile(session, path, function); + } + + public CodeInterpreterFiles listFiles(ProxyContext context, String sessionId) { + verifyActive(); + verifySessionId(sessionId); + + CodeInterpreterSession session = touchSession(context, sessionId); + return client.listFiles(session); + } + + @SneakyThrows + public CodeInterpreterFile transferInputFile(ProxyContext context, CodeInterpreterInputFile file) { + verifyActive(); + verifySessionId(file.getSessionId()); + verifyPath(file.getTargetPath()); + + ResourceDescriptor resource = verifyFile(context, file.getSourceUrl(), true); + ResourceService.ResourceStream input = resourceService.getResourceStream(resource, EtagHeader.ANY); + + if (input == null) { + throw new ResourceNotFoundException("File is not found: " + resource.getUrl()); + } + + return uploadFile(context, file.getSessionId(), file.getTargetPath(), input.inputStream()); + } + + public FileMetadata transferOutputFile(ProxyContext context, 
CodeInterpreterOutputFile file) { + verifyActive(); + verifySessionId(file.getSessionId()); + verifyPath(file.getSourcePath()); + + ResourceDescriptor resource = verifyFile(context, file.getTargetUrl(), false); + + return downloadFile(context, file.getSessionId(), file.getSourcePath(), (input, size) -> { + BlobWriteStream output = new BlobWriteStream(vertx, resourceService, + context.getProxy().getStorage(), resource, EtagHeader.ANY, null); + + return new InputStreamReader(vertx, input) + .pipe() + .endOnFailure(false) + .to(output) + .onFailure(output::abortUpload) + .map(success -> output.getMetadata()); + }); + } + + private ResourceDescriptor sessionResource(ProxyContext context, String sessionId) { + AuthBucket bucket = BucketBuilder.buildBucket(context); + + try { + String path = (bucket.getAppBucket() == null) + ? ("user/" + sessionId) + : ("app/" + bucket.getAppBucket() + "/" + sessionId); + + return ResourceDescriptorFactory.fromEncoded(ResourceTypes.CODE_INTERPRETER_SESSION, + bucket.getUserBucket(), bucket.getUserBucketLocation(), path); + } catch (Throwable e) { + throw new IllegalArgumentException("Invalid sessionId: " + sessionId); + } + } + + private boolean isActive() { + return sessionImage != null && operatorService.isActive(); + } + + private void verifyActive() { + if (!isActive()) { + throw new HttpException(HttpStatus.SERVICE_UNAVAILABLE, "Code interpreter is not available"); + } + } + + private ResourceDescriptor verifyFile(ProxyContext context, String url, boolean input) { + ResourceDescriptor resource; + try { + resource = ResourceDescriptorFactory.fromAnyUrl(url, encryptionService); + if (resource.getType() != ResourceTypes.FILE) { + throw new IllegalArgumentException(); + } + } catch (Throwable e) { + throw new IllegalArgumentException("Bad file url:" + url); + } + + boolean isAccessible = input + ? 
accessService.hasReadAccess(resource, context) + : accessService.hasWriteAccess(resource, context); + + if (!isAccessible) { + throw new PermissionDeniedException("File is not accessible: " + resource.getUrl()); + } + + return resource; + } + + private static void verifySessionId(String sessionId) { + if (sessionId == null) { + throw new IllegalArgumentException("Missing sessionId"); + } + } + + private static void verifyCode(CodeInterpreterExecute request) { + if (request.getCode() == null) { + throw new IllegalArgumentException("Missing code"); + } + } + + private static void verifyPath(String path) { + try { + Path ignore = Path.of(path); + } catch (Throwable e) { + throw new IllegalArgumentException("Bad file path:" + path); + } + } +} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/util/BucketBuilder.java b/server/src/main/java/com/epam/aidial/core/server/util/BucketBuilder.java index 2c605d379..73d8be0d8 100644 --- a/server/src/main/java/com/epam/aidial/core/server/util/BucketBuilder.java +++ b/server/src/main/java/com/epam/aidial/core/server/util/BucketBuilder.java @@ -1,8 +1,11 @@ package com.epam.aidial.core.server.util; import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.AuthBucket; +import com.epam.aidial.core.server.security.EncryptionService; import lombok.experimental.UtilityClass; +import java.util.Objects; import javax.annotation.Nullable; @UtilityClass @@ -44,4 +47,27 @@ public static String buildInitiatorBucket(ProxyContext context) { throw new IllegalArgumentException("Can't find user bucket. 
Either user sub or api-key project must be provided"); } + public AuthBucket buildBucket(ProxyContext context) { + EncryptionService encryption = context.getProxy().getEncryptionService(); + String perRequestKey = context.getApiKeyData().getPerRequestKey(); + + String userBucketLocation = buildInitiatorBucket(context); + String userBucket = encryption.encrypt(userBucketLocation); + + String appBucket = null; + String appBucketLocation = null; + + if (perRequestKey != null) { + Objects.requireNonNull(context.getSourceDeployment()); + appBucketLocation = API_KEY_BUCKET_PATTERN.formatted(context.getSourceDeployment()); + appBucket = encryption.encrypt(appBucketLocation); + } + + AuthBucket bucket = new AuthBucket(); + bucket.setUserBucket(userBucket); + bucket.setUserBucketLocation(userBucketLocation); + bucket.setAppBucket(appBucket); + bucket.setAppBucketLocation(appBucketLocation); + return bucket; + } } diff --git a/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java b/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java new file mode 100644 index 000000000..8b719ca84 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java @@ -0,0 +1,122 @@ +package com.epam.aidial.core.server.vertx.stream; + +import io.vertx.core.buffer.Buffer; +import io.vertx.core.streams.ReadStream; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicLong; + +public class InputStreamAdapter extends InputStream { + + private static final Buffer END_PILL = Buffer.buffer(); + private static final int LOW_MEMORY_BYTES = 1024 * 1024; + private static final int HIGH_MEMORY_BYTES = 4 * 1024 * 1024; + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private final AtomicLong memory = new 
AtomicLong(); + private final ReadStream stream; + + private volatile IOException error; + private Buffer current; + private int position; + + public InputStreamAdapter(ReadStream stream) { + this.stream = stream; + stream.handler(this::onData) + .endHandler(this::onEnd) + .exceptionHandler(this::onError); + } + + private void onData(Buffer data) { + if (data.length() > 0 && error == null) { + queue.add(data); + update(data.length()); + } + } + + private void onEnd(Void data) { + queue.add(END_PILL); + } + + private void onError(Throwable exception) { + error = new IOException(exception); + queue.add(END_PILL); + } + + @Override + public int read() throws IOException { + // so dumb - not really used + byte[] array = new byte[1]; + int size = read(array, 0, 1); + return (size <= 0) ? -1 : (array[0] & 0xFF); + } + + @Override + public synchronized int read(byte[] array, int offset, int length) throws IOException { + Objects.checkFromIndexSize(offset, length, array.length); + + if (error != null) { + throw error; + } + + if (current == END_PILL) { + return -1; + } + + if (length == 0) { + return 0; + } + + try { + int size = 0; + + while (size < length) { + if (current == null) { + current = queue.take(); + + if (error != null) { + throw error; + } + + if (current == END_PILL) { + break; + } + } + + int chunk = Math.min(length - size, current.length() - position); + current.getBytes(position, position + chunk, array, offset + size); + position += chunk; + size += chunk; + + if (position == current.length()) { + update(-current.length()); + current = null; + position = 0; + } + } + + return size == 0 ? 
-1 : size; + } catch (InterruptedException e) { + error = new IOException(e); + throw new IOException(error); + } + } + + private void update(long delta) { + long footprint = memory.addAndGet(delta); + if (footprint <= LOW_MEMORY_BYTES) { + stream.fetch(HIGH_MEMORY_BYTES - footprint); + } else if (footprint >= HIGH_MEMORY_BYTES) { + stream.pause(); + } + } + + @Override + public synchronized void close() { + error = new IOException("closed"); + } +} diff --git a/server/src/test/java/com/epam/aidial/core/server/CodeInterpreterApiTest.java b/server/src/test/java/com/epam/aidial/core/server/CodeInterpreterApiTest.java new file mode 100644 index 000000000..b180b732c --- /dev/null +++ b/server/src/test/java/com/epam/aidial/core/server/CodeInterpreterApiTest.java @@ -0,0 +1,132 @@ +package com.epam.aidial.core.server; + +import io.vertx.core.http.HttpMethod; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class CodeInterpreterApiTest extends ResourceBaseTest { + + private TestWebServer webServer; + + @BeforeEach + void initWebServer() { + webServer = new TestWebServer(17321); + } + + @AfterEach + void destroyDeploymentService() { + try (TestWebServer server = webServer) { + // closing + } + } + + @Test + void testStatefulWorkflow() { + webServer.map(HttpMethod.POST, "/v1/deployment/0124", 200, """ + event: result + data: {"url":"http://localhost:17321"}"""); + Response response = send(HttpMethod.POST, "/v1/ops/code_interpreter/open_session", null, """ + {}"""); + verifyJson(response, 200, """ + {"sessionId":"0123"}"""); + + webServer.map(HttpMethod.POST, "/execute_code", 200, """ + {"status":"SUCCESS","stdout":"","stderr":"","result":{"text/plain":"3"},"display":[]}"""); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/execute_code", null, """ + {"sessionId":"0123","code":"1+2"}"""); + verifyJson(response, 200, """ + 
{"status":"SUCCESS","stdout":"","stderr":"","result":{"text/plain":"3"},"display":[]}"""); + + webServer.map(HttpMethod.POST, "/list_files", 200, """ + {"files": []}"""); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/list_files", null, """ + {"sessionId":"0123"} + """); + verifyJson(response, 200, """ + {"files": []}"""); + + String content = "1".repeat(16 * 1024 * 1024); + webServer.map(HttpMethod.POST, "/upload_file", 200, """ + {"path": "/mnt/data/file.txt","size": 16777216}"""); + response = upload(HttpMethod.POST, "/v1/ops/code_interpreter/upload_file", "session_id=0123", content); + verifyJson(response, 200, """ + {"path": "/mnt/data/file.txt","size": 16777216}"""); + + webServer.map(HttpMethod.POST, "/download_file", 200, content); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/download_file", null, """ + {"sessionId":"0123","path":"file.txt"}"""); + verify(response, 200, content); + + webServer.map(HttpMethod.POST, "/list_files", 200, """ + {"files": [{"path": "/mnt/data/file.txt","size": 16777216}]}"""); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/list_files", null, """ + {"sessionId":"0123"}"""); + verifyJson(response, 200, """ + {"files": [{"path": "/mnt/data/file.txt","size": 16777216}]}"""); + + content += "2"; + response = upload(HttpMethod.PUT, "/v1/files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/file2.txt", null, content); + verify(response, 200); /* asserts the PUT above, not the earlier list_files response */ + + webServer.map(HttpMethod.POST, "/upload_file", 200, """ + {"path": "/mnt/data/file2.txt","size": 16777217}"""); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/transfer_input_file", null, """ + {"sessionId":"0123","sourceUrl":"files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/file2.txt","targetPath":"file2.txt"}"""); + verifyJson(response, 200, """ + {"path": "/mnt/data/file2.txt","size": 16777217}"""); + + webServer.map(HttpMethod.POST, "/download_file", 200, content); + response = send(HttpMethod.POST,
"/v1/ops/code_interpreter/transfer_output_file", null, """ + {"sessionId":"0123","sourcePath":"file2.txt","targetUrl":"files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/file3.txt"}"""); + verify(response, 200); + + response = send(HttpMethod.GET, "/v1/files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/file3.txt", null, ""); + verify(response, 200, content); + + webServer.map(HttpMethod.DELETE, "/v1/deployment/0124", 200, """ + event: result + data: {"deleted":true}"""); + response = send(HttpMethod.POST, "/v1/ops/code_interpreter/close_session", null, """ + {"sessionId":"0123"}"""); + verifyJson(response, 200, """ + {"sessionId":"0123"}"""); + } + + @Test + void testStatelessWorkflow() { + String inputContent = "1".repeat(1024); + String outputContent = "2".repeat(2048); + + upload(HttpMethod.PUT, "/v1/files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/input-file.txt", null, inputContent); + + webServer.map(HttpMethod.POST, "/v1/deployment/0124", 200, """ + event: result + data: {"url":"http://localhost:17321"}"""); + + webServer.map(HttpMethod.POST, "/upload_file", 200, """ + {"path": "/mnt/data/input-file.txt","size": 1024}"""); + + webServer.map(HttpMethod.POST, "/execute_code", 200, """ + {"status":"SUCCESS","stdout":"","stderr":"","result":{"text/plain":"3"},"display":[]}"""); + + webServer.map(HttpMethod.POST, "/download_file", 200, outputContent); + + webServer.map(HttpMethod.DELETE, "/v1/deployment/0124", 200, """ + event: result + data: {"deleted":true}"""); + + + Response response = send(HttpMethod.POST, "/v1/ops/code_interpreter/execute_code", null, """ + { + "code":"1+2", + "inputFiles":[{"sourceUrl":"files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/input-file.txt","targetPath":"input-file.txt"}], + "outputFiles":[{"targetUrl":"files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/output-file.txt","sourcePath":"output-file.txt"}] + }"""); + verifyJson(response, 200, """ + 
{"status":"SUCCESS","stdout":"","stderr":"","result":{"text/plain":"3"},"display":[]}"""); + + response = send(HttpMethod.GET, "/v1/files/3CcedGxCx23EwiVbVmscVktScRyf46KypuBQ65miviST/output-file.txt", null, ""); + verify(response, 200, outputContent); + } +} diff --git a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java index 6e2e6ad1e..d459f85b9 100644 --- a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java @@ -157,6 +157,9 @@ void init() throws Exception { "controllerEndpoint": "http://localhost:17321", "checkDelay": 1000, "checkPeriod": 1000 + }, + "codeInterpreter" : { + "sessionImage": "fake.image" } } """.formatted(Json.encode(testDir.toString())); diff --git a/storage/src/main/java/com/epam/aidial/core/storage/service/ResourceService.java b/storage/src/main/java/com/epam/aidial/core/storage/service/ResourceService.java index 8712cec4e..4b4377b73 100644 --- a/storage/src/main/java/com/epam/aidial/core/storage/service/ResourceService.java +++ b/storage/src/main/java/com/epam/aidial/core/storage/service/ResourceService.java @@ -132,6 +132,20 @@ public ResourceTopic.Subscription subscribeResources(Collection result = getResourceWithMetadata(descriptor, etag, lock); return (result == null) ? 
null : result.getRight(); } @@ -367,7 +381,7 @@ public ResourceItemMetadata putResource( return putResource(descriptor, body, etag, true); } - private ResourceItemMetadata putResource( + public ResourceItemMetadata putResource( ResourceDescriptor descriptor, String body, EtagHeader etag, boolean lock) { byte[] bytes = body.getBytes(StandardCharsets.UTF_8); return putResource(descriptor, bytes, etag, "application/json", lock); @@ -487,7 +501,7 @@ public boolean deleteResource(ResourceDescriptor descriptor, EtagHeader etag) { return deleteResource(descriptor, etag, true); } - private boolean deleteResource(ResourceDescriptor descriptor, EtagHeader etag, boolean lock) { + public boolean deleteResource(ResourceDescriptor descriptor, EtagHeader etag, boolean lock) { String redisKey = redisKey(descriptor); try (var ignore = lock ? lockService.lock(redisKey) : null) { From 4b85f8efb4b7b5f59b2a688b60984c95fbd4435e Mon Sep 17 00:00:00 2001 From: Artsiom Korzun Date: Fri, 10 Jan 2025 13:05:11 +0100 Subject: [PATCH 2/5] address review comments --- .../epam/aidial/core/server/ProxyContext.java | 8 +- .../controller/CodeInterpreterController.java | 8 +- .../CodeInterpreterClient.java | 10 +- .../codeinterpreter/CodeInterpreterError.java | 13 -- .../vertx/stream/InputStreamAdapter.java | 27 ++-- .../vertx/stream/InputStreamAdapterTest.java | 139 ++++++++++++++++++ .../core/storage/http/HttpException.java | 4 + .../aidial/core/storage/http/HttpStatus.java | 6 +- 8 files changed, 177 insertions(+), 38 deletions(-) delete mode 100644 server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java create mode 100644 server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java diff --git a/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java b/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java index e250db99c..47525ccf6 100644 --- 
a/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java +++ b/server/src/main/java/com/epam/aidial/core/server/ProxyContext.java @@ -146,20 +146,16 @@ public Future respond(HttpStatus status, String contentType, Object object) { } public Future respond(HttpStatus status, String body) { - return respond(status.getCode(), body); - } - - public Future respond(int status, String body) { if (body == null) { body = ""; } - if (status != HttpStatus.OK.getCode()) { + if (status != HttpStatus.OK) { log.warn("Responding with error. Project: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status, body.length() > LOG_MAX_ERROR_LENGTH ? body.substring(0, LOG_MAX_ERROR_LENGTH) : body); } - response.setStatusCode(status).end(body); + response.setStatusCode(status.getCode()).end(body); return Future.succeededFuture(); } diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java index a2b5effc9..aad1de405 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java @@ -9,7 +9,6 @@ import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSessionId; import com.epam.aidial.core.server.service.PermissionDeniedException; import com.epam.aidial.core.server.service.ResourceNotFoundException; -import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterError; import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService; import com.epam.aidial.core.server.util.ProxyUtil; import com.epam.aidial.core.server.vertx.stream.InputStreamAdapter; @@ -84,9 +83,10 @@ Future uploadFile() { .setExpectMultipart(true) .uploadHandler(upload -> { // do not move inside execute blocking, otherwise you can miss the beginning of file - InputStream 
stream = new InputStreamAdapter(upload); + InputStreamAdapter stream = new InputStreamAdapter(upload); vertx.executeBlocking(() -> uploadFile(upload, stream), false) .onSuccess(this::respondJson) + .onComplete(e -> stream.close()) .onFailure(this::respondError); }); @@ -182,6 +182,8 @@ private void respondJson(Object data) { private void respondError(Throwable error) { HttpServerResponse response = context.getResponse(); if (response.headWritten()) { + // download request can partially fail, when some data already is sent, it is too late to send response + // so the only option is to disconnect client response.reset(); } else if (error instanceof IllegalArgumentException) { context.respond(HttpStatus.BAD_REQUEST, error.getMessage()); @@ -191,8 +193,6 @@ private void respondError(Throwable error) { context.respond(HttpStatus.NOT_FOUND, error.getMessage()); } else if (error instanceof HttpException e) { context.respond(e.getStatus(), e.getMessage()); - } else if (error instanceof CodeInterpreterError e) { - context.respond(e.getStatus(), e.getMessage()); } else { log.error("Failed to handle code interpreter request", error); context.respond(error, "Internal error"); diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java index cf90016ac..fdbe83c7e 100644 --- a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java @@ -5,6 +5,8 @@ import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFiles; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession; import com.epam.aidial.core.server.util.ProxyUtil; +import com.epam.aidial.core.storage.http.HttpException; +import com.epam.aidial.core.storage.http.HttpStatus; import 
io.vertx.core.Future; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; @@ -54,7 +56,7 @@ CodeInterpreterFile uploadFile(CodeInterpreterSession session, InputStream sourc String body = EntityUtils.toString(response.getEntity()); if (status != 200) { - throw new CodeInterpreterError(status, body); + throw new HttpException(status, body); } return ProxyUtil.convertToObject(body, CodeInterpreterFile.class); @@ -73,7 +75,7 @@ R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunc if (status != 200) { String body = EntityUtils.toString(entity); - throw new CodeInterpreterError(status, body); + throw new HttpException(status, body); } try { @@ -88,7 +90,7 @@ R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunc return result.get(timeout, TimeUnit.MILLISECONDS); } catch (Throwable e) { EntityUtils.consumeQuietly(entity); - throw new CodeInterpreterError(500, "Failed to download file: " + path); + throw new HttpException(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to download file: " + path); } }); } @@ -104,7 +106,7 @@ private R execute(CodeInterpreterSession session, String path, Object reques String body = EntityUtils.toString(response.getEntity()); if (status != 200) { - throw new CodeInterpreterError(status, body); + throw new HttpException(status, body); } return ProxyUtil.convertToObject(body, responseType); diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java deleted file mode 100644 index 1fe97b0e0..000000000 --- a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterError.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.epam.aidial.core.server.service.codeinterpreter; - -import lombok.Getter; - -@Getter -public class CodeInterpreterError extends RuntimeException { - private final int status; - - public 
CodeInterpreterError(int status, String body) { - super(body); - this.status = status; - } -} \ No newline at end of file diff --git a/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java b/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java index 8b719ca84..1250f2112 100644 --- a/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java +++ b/server/src/main/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapter.java @@ -17,7 +17,7 @@ public class InputStreamAdapter extends InputStream { private static final int HIGH_MEMORY_BYTES = 4 * 1024 * 1024; private final BlockingQueue queue = new LinkedBlockingQueue<>(); - private final AtomicLong memory = new AtomicLong(); + private final AtomicLong queuedMemorySize = new AtomicLong(); private final ReadStream stream; private volatile IOException error; @@ -39,12 +39,16 @@ private void onData(Buffer data) { } private void onEnd(Void data) { - queue.add(END_PILL); + if (error == null) { + queue.add(END_PILL); + } } private void onError(Throwable exception) { - error = new IOException(exception); - queue.add(END_PILL); + if (error == null) { + error = new IOException(exception); + queue.add(END_PILL); + } } @Override @@ -107,16 +111,19 @@ public synchronized int read(byte[] array, int offset, int length) throws IOExce } private void update(long delta) { - long footprint = memory.addAndGet(delta); - if (footprint <= LOW_MEMORY_BYTES) { - stream.fetch(HIGH_MEMORY_BYTES - footprint); - } else if (footprint >= HIGH_MEMORY_BYTES) { + long size = queuedMemorySize.addAndGet(delta); + if (size <= LOW_MEMORY_BYTES) { + stream.fetch(HIGH_MEMORY_BYTES - size); + } else if (size >= HIGH_MEMORY_BYTES) { stream.pause(); } } @Override - public synchronized void close() { - error = new IOException("closed"); + public void close() { + if (error == null) { + error = new IOException("closed"); + queue.add(END_PILL); + } } } diff --git 
a/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java b/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java new file mode 100644 index 000000000..2e3d1f6b9 --- /dev/null +++ b/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java @@ -0,0 +1,139 @@ +package com.epam.aidial.core.server.vertx.stream; + +import io.vertx.core.Handler; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.streams.ReadStream; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + + +class InputStreamAdapterTest { + + protected static final String TEXT = """ + Line1 + Line2 + Line3 + Line4 + Line5 + """; + + @Test + void testReadOneChunk() throws Exception { + TestReadStream source = new TestReadStream(); + InputStreamAdapter stream = new InputStreamAdapter(source); + + source.append(TEXT).end(); + Assertions.assertEquals(TEXT, new String(stream.readAllBytes(), StandardCharsets.UTF_8)); + Assertions.assertEquals(-1, stream.read()); + + stream.close(); + Assertions.assertThrows(IOException.class, stream::read); + } + + @Test + void testReadManyChunks() throws Exception { + TestReadStream source = new TestReadStream(); + InputStreamAdapter stream = new InputStreamAdapter(source); + + for (String chunk : TEXT.split("\n")) { + source.append(chunk).append("\n"); + } + + source.end(); + Assertions.assertEquals(TEXT, new String(stream.readAllBytes(), StandardCharsets.UTF_8)); + Assertions.assertEquals(-1, stream.read()); + + stream.close(); + Assertions.assertThrows(IOException.class, stream::read); + } + + @Test + void testReadChunkByChunk() throws Exception { + TestReadStream source = new TestReadStream(); + InputStreamAdapter stream = new InputStreamAdapter(source); + + for (String chunk : TEXT.split("\n")) { + source.append(chunk); + byte[] bytes = new byte[chunk.length() + 2]; + 
int read = stream.read(bytes, 1, chunk.length()); + Assertions.assertEquals(chunk.length(), read); + Assertions.assertEquals(chunk, new String(bytes, 1, chunk.length(), StandardCharsets.UTF_8)); + } + + source.end(); + Assertions.assertEquals(-1, stream.read()); + + stream.close(); + Assertions.assertThrows(IOException.class, stream::read); + } + + @Test + void testError() { + TestReadStream source = new TestReadStream(); + InputStreamAdapter stream = new InputStreamAdapter(source); + + source.append(TEXT); + source.error(new IllegalAccessError("NotAccess")); + + Assertions.assertThrows(IOException.class, stream::read); + } + + private static class TestReadStream implements ReadStream { + + private Handler dataHandler; + private Handler endHandler; + private Handler errorHandler; + + TestReadStream append(String text) { + dataHandler.handle(Buffer.buffer(text)); + return this; + } + + TestReadStream end() { + endHandler.handle(null); + return this; + } + + TestReadStream error(Throwable error) { + errorHandler.handle(error); + return this; + } + + @Override + public TestReadStream handler(Handler dataHandler) { + this.dataHandler = dataHandler; + return this; + } + + @Override + public TestReadStream endHandler(Handler endHandler) { + this.endHandler = endHandler; + return this; + } + + @Override + public TestReadStream exceptionHandler(Handler handler) { + this.errorHandler = handler; + return this; + } + + @Override + public TestReadStream pause() { + return this; + } + + @Override + public TestReadStream resume() { + return this; + } + + @Override + public TestReadStream fetch(long amount) { + return this; + } + } + +} diff --git a/storage/src/main/java/com/epam/aidial/core/storage/http/HttpException.java b/storage/src/main/java/com/epam/aidial/core/storage/http/HttpException.java index 5f643e09a..86900272e 100644 --- a/storage/src/main/java/com/epam/aidial/core/storage/http/HttpException.java +++ 
b/storage/src/main/java/com/epam/aidial/core/storage/http/HttpException.java @@ -11,6 +11,10 @@ public class HttpException extends RuntimeException { private final HttpStatus status; private final Map headers; + public HttpException(int status, String message) { + this(HttpStatus.fromStatusCode(status, HttpStatus.INTERNAL_SERVER_ERROR), message, Map.of()); + } + public HttpException(HttpStatus status, String message) { this(status, message, Map.of()); } diff --git a/storage/src/main/java/com/epam/aidial/core/storage/http/HttpStatus.java b/storage/src/main/java/com/epam/aidial/core/storage/http/HttpStatus.java index fb04c8cdc..a9308ff61 100644 --- a/storage/src/main/java/com/epam/aidial/core/storage/http/HttpStatus.java +++ b/storage/src/main/java/com/epam/aidial/core/storage/http/HttpStatus.java @@ -33,6 +33,10 @@ public boolean is5xx() { } public static HttpStatus fromStatusCode(int code) { + return fromStatusCode(code, INTERNAL_SERVER_ERROR); + } + + public static HttpStatus fromStatusCode(int code, HttpStatus fallback) { return switch (code) { case 200 -> OK; case 304 -> NOT_MODIFIED; @@ -52,7 +56,7 @@ public static HttpStatus fromStatusCode(int code) { case 503 -> SERVICE_UNAVAILABLE; case 504 -> GATEWAY_TIMEOUT; case 505 -> HTTP_VERSION_NOT_SUPPORTED; - default -> throw new IllegalArgumentException("Unknown HTTP status code: " + code); + default -> fallback; }; } } From b3cac113baa78fc5e64485fe8f02275152225b1f Mon Sep 17 00:00:00 2001 From: Artsiom Korzun Date: Fri, 10 Jan 2025 15:32:14 +0100 Subject: [PATCH 3/5] address review comments --- .../controller/CodeInterpreterController.java | 4 +- ...ava => CodeInterpreterExecuteRequest.java} | 2 +- ...va => CodeInterpreterExecuteResponse.java} | 2 +- .../CodeInterpreterClient.java | 20 +++++---- .../CodeInterpreterService.java | 42 ++++++++++--------- .../vertx/stream/InputStreamAdapterTest.java | 30 +++++++++++++ 6 files changed, 69 insertions(+), 31 deletions(-) rename 
server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/{CodeInterpreterExecute.java => CodeInterpreterExecuteRequest.java} (89%) rename server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/{CodeInterpreterExecution.java => CodeInterpreterExecuteResponse.java} (87%) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java index aad1de405..bf88c337c 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/CodeInterpreterController.java @@ -1,7 +1,7 @@ package com.epam.aidial.core.server.controller; import com.epam.aidial.core.server.ProxyContext; -import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecuteRequest; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile; @@ -69,7 +69,7 @@ Future executeCode() { context.getRequest() .body() .compose(body -> { - CodeInterpreterExecute data = convertJson(body, CodeInterpreterExecute.class); + CodeInterpreterExecuteRequest data = convertJson(body, CodeInterpreterExecuteRequest.class); return vertx.executeBlocking(() -> service.executeCode(context, data), false); }) .onSuccess(this::respondJson) diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteRequest.java similarity index 89% rename from server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java rename to 
server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteRequest.java index 4465e0f08..4edaeeec5 100644 --- a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecute.java +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteRequest.java @@ -7,7 +7,7 @@ @Data @JsonInclude(JsonInclude.Include.NON_NULL) -public class CodeInterpreterExecute { +public class CodeInterpreterExecuteRequest { private String sessionId; private String code; private List inputFiles; diff --git a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteResponse.java similarity index 87% rename from server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java rename to server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteResponse.java index 0251c253c..28c85f154 100644 --- a/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecution.java +++ b/server/src/main/java/com/epam/aidial/core/server/data/codeinterpreter/CodeInterpreterExecuteResponse.java @@ -5,7 +5,7 @@ @Data @JsonInclude(JsonInclude.Include.NON_NULL) -public class CodeInterpreterExecution { +public class CodeInterpreterExecuteResponse { private String status; private String stdout; private String stderr; diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java index fdbe83c7e..62c9ce1c7 100644 --- a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterClient.java @@ -1,6 +1,6 @@ 
package com.epam.aidial.core.server.service.codeinterpreter; -import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecution; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecuteResponse; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFiles; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession; @@ -31,11 +31,11 @@ public class CodeInterpreterClient { // Vertx HttpClient does not support multipart upload, Vertx WebClient supports only Buffer as body for multipart upload private final HttpClient client = HttpClients.createDefault(); - private final long timeout; + private final long responseTimeout; - CodeInterpreterExecution executeCode(CodeInterpreterSession session, String code) { + CodeInterpreterExecuteResponse executeCode(CodeInterpreterSession session, String code) { Map body = Map.of("code", code); - return execute(session, "/execute_code", body, CodeInterpreterExecution.class); + return execute(session, "/execute_code", body, CodeInterpreterExecuteResponse.class); } CodeInterpreterFiles listFiles(CodeInterpreterSession session) { @@ -46,7 +46,7 @@ CodeInterpreterFiles listFiles(CodeInterpreterSession session) { @SneakyThrows CodeInterpreterFile uploadFile(CodeInterpreterSession session, InputStream source, String target) { HttpPost post = new HttpPost(session.getDeploymentUrl() + "/upload_file"); - post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setConfig(createRequestConfig()); post.setEntity(MultipartEntityBuilder.create() .addBinaryBody("file", source, ContentType.APPLICATION_OCTET_STREAM, target) .build()); @@ -66,7 +66,7 @@ CodeInterpreterFile uploadFile(CodeInterpreterSession session, InputStream sourc @SneakyThrows R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunction consumer) { HttpPost post = new 
HttpPost(session.getDeploymentUrl() + "/download_file"); - post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setConfig(createRequestConfig()); post.setEntity(HttpEntities.create(ProxyUtil.convertToString(Map.of("path", path)), ContentType.APPLICATION_JSON)); return client.execute(post, response -> { @@ -87,7 +87,7 @@ R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunc .onSuccess(result::complete) .onFailure(result::completeExceptionally); - return result.get(timeout, TimeUnit.MILLISECONDS); + return result.get(responseTimeout, TimeUnit.MILLISECONDS); } catch (Throwable e) { EntityUtils.consumeQuietly(entity); throw new HttpException(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to download file: " + path); @@ -98,7 +98,7 @@ R downloadFile(CodeInterpreterSession session, String path, DownloadFileFunc @SneakyThrows private R execute(CodeInterpreterSession session, String path, Object requestPayload, Class responseType) { HttpPost post = new HttpPost(session.getDeploymentUrl() + path); - post.setConfig(RequestConfig.custom().setResponseTimeout(timeout, TimeUnit.MILLISECONDS).build()); + post.setConfig(createRequestConfig()); post.setEntity(HttpEntities.create(ProxyUtil.convertToString(requestPayload), ContentType.APPLICATION_JSON)); return client.execute(post, response -> { @@ -113,6 +113,10 @@ private R execute(CodeInterpreterSession session, String path, Object reques }); } + private RequestConfig createRequestConfig() { + return RequestConfig.custom().setResponseTimeout(responseTimeout, TimeUnit.MILLISECONDS).build(); + } + public interface DownloadFileFunction { Future apply(InputStream stream, long size) throws Throwable; } diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java index 3c3a6e4a1..5db1f6648 100644 --- 
a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java @@ -3,8 +3,8 @@ import com.epam.aidial.core.server.ProxyContext; import com.epam.aidial.core.server.data.AuthBucket; import com.epam.aidial.core.server.data.ResourceTypes; -import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute; -import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecution; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecuteRequest; +import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecuteResponse; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFiles; import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile; @@ -210,7 +210,7 @@ public CodeInterpreterSession closeSession(ProxyContext context, String sessionI } } - public CodeInterpreterExecution executeCode(ProxyContext context, CodeInterpreterExecute request) { + public CodeInterpreterExecuteResponse executeCode(ProxyContext context, CodeInterpreterExecuteRequest request) { verifyActive(); verifyCode(request); @@ -223,27 +223,31 @@ public CodeInterpreterExecution executeCode(ProxyContext context, CodeInterprete session = touchSession(context, request.getSessionId()); } - if (request.getInputFiles() != null) { - for (CodeInterpreterInputFile input : request.getInputFiles()) { - input.setSessionId(session.getSessionId()); - transferInputFile(context, input); + try { + if (request.getInputFiles() != null) { + for (CodeInterpreterInputFile input : request.getInputFiles()) { + input.setSessionId(session.getSessionId()); + transferInputFile(context, input); + } } - } - CodeInterpreterExecution response = client.executeCode(session, request.getCode()); + CodeInterpreterExecuteResponse 
response = client.executeCode(session, request.getCode()); - if (request.getOutputFiles() != null) { - for (CodeInterpreterOutputFile output : request.getOutputFiles()) { - output.setSessionId(session.getSessionId()); - transferOutputFile(context, output); + if (request.getOutputFiles() != null) { + for (CodeInterpreterOutputFile output : request.getOutputFiles()) { + output.setSessionId(session.getSessionId()); + transferOutputFile(context, output); + } } - } - if (anonymous) { - closeSession(context, session.getSessionId()); - } + return response; + } catch (Throwable e) { + if (anonymous) { + closeSession(context, session.getSessionId()); + } - return response; + throw e; + } } @SneakyThrows @@ -364,7 +368,7 @@ private static void verifySessionId(String sessionId) { } } - private static void verifyCode(CodeInterpreterExecute request) { + private static void verifyCode(CodeInterpreterExecuteRequest request) { if (request.getCode() == null) { throw new IllegalArgumentException("Missing code"); } diff --git a/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java b/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java index 2e3d1f6b9..058fe2ca6 100644 --- a/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/vertx/stream/InputStreamAdapterTest.java @@ -1,13 +1,19 @@ package com.epam.aidial.core.server.vertx.stream; import io.vertx.core.Handler; +import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; +import io.vertx.core.file.AsyncFile; +import io.vertx.core.file.OpenOptions; import io.vertx.core.streams.ReadStream; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.concurrent.CompletableFuture; class InputStreamAdapterTest { @@ 
-70,6 +76,30 @@ void testReadChunkByChunk() throws Exception { Assertions.assertThrows(IOException.class, stream::read); } + @Test + void testBigFile() throws Exception { + Path file = Files.createTempFile("input-stream-test", ".txt"); + String text = "1".repeat(64 * 1024 * 1024); + Vertx vertx = Vertx.vertx(); + + try { + Files.writeString(file, text); + + CompletableFuture future = new CompletableFuture<>(); + vertx.fileSystem().open(file.toString(), new OpenOptions()) + .onSuccess(future::complete) + .onFailure(future::completeExceptionally); + + AsyncFile source = future.get(); + InputStreamAdapter stream = new InputStreamAdapter(source); + + Assertions.assertEquals(text, new String(stream.readAllBytes(), StandardCharsets.UTF_8)); + } finally { + Files.deleteIfExists(file); + vertx.close(); + } + } + @Test void testError() { TestReadStream source = new TestReadStream(); From e7714c36c6e5d06e652f3b433ba141c8c15b7071 Mon Sep 17 00:00:00 2001 From: Artsiom Korzun Date: Fri, 10 Jan 2025 15:39:30 +0100 Subject: [PATCH 4/5] address review comments --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1ff095569..07d9e29e7 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Priority order: | encryption.secret | - | No |Secret is used for AES encryption of a prefix to the bucket blob storage. The value should be random generated string. | encryption.key | - | No |Key is used for AES encryption of a prefix to the bucket blob storage. The value should be random generated string. | resources.maxSize | 67108864 | No |Max allowed size in bytes for a resource. -| resources.maxSizeToCache | 1048576 | No |Max size in bytes for a resource to cache in Redis. +| resources.maxSizeToCache | 1048576 | No |Max size in bytes for a resource to cache in Redis. | resources.syncPeriod | 60000 | No |Period in milliseconds, how frequently check for resources to sync. 
| resources.syncDelay | 120000 | No |Delay in milliseconds for a resource to be written back in object storage after last modification. | resources.syncBatch | 4096 | No |How many resources to sync in one go. @@ -97,7 +97,10 @@ Priority order: | applications.includeCustomApps | false | No |The flag indicates whether custom applications should be included into openai listing | applications.controllerEndpoint | - | No |The endpoint to Application Controller Web Service that manages deployments for applications with functions | applications.controllerTimeout | 240000 | No |The timeout of operations to Application Controller Web Service -| applications.checkPeriod | 300000 | No |The interval at which to check the pending operations for applications with functions +| codeInterpreter.sessionImage | - | No |The code interpreter session image to use +| codeInterpreter.sessionTtl | 600000 | No |The session time to live after the last API call +| codeInterpreter.checkPeriod | 10000 | No |The interval at which to check active sessions for expiration +| codeInterpreter.checkSize | 256 | No |The maximum number of active sessions to check in a single check ### Storage requirements From 48ccd06f2d8991a95776bf24355ac2b0d1de289e Mon Sep 17 00:00:00 2001 From: Artsiom Korzun Date: Fri, 10 Jan 2025 16:02:00 +0100 Subject: [PATCH 5/5] address review comments --- .../service/codeinterpreter/CodeInterpreterService.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java index 5db1f6648..562b316de 100644 --- a/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java +++ b/server/src/main/java/com/epam/aidial/core/server/service/codeinterpreter/CodeInterpreterService.java @@ -241,12 +241,10 @@ public
CodeInterpreterExecuteResponse executeCode(ProxyContext context, CodeInte } return response; - } catch (Throwable e) { + } finally { if (anonymous) { closeSession(context, session.getSessionId()); } - - throw e; } }