Skip to content

Commit

Permalink
add code interpreter feature
Browse files Browse the repository at this point in the history
  • Loading branch information
artsiomkorzun committed Jan 9, 2025
1 parent 6f03973 commit 29f3079
Show file tree
Hide file tree
Showing 26 changed files with 1,225 additions and 20 deletions.
2 changes: 1 addition & 1 deletion server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ dependencies {
implementation 'org.hibernate.validator:hibernate-validator:8.0.0.Final'
implementation 'org.glassfish:jakarta.el:4.0.2'
implementation 'jakarta.validation:jakarta.validation-api:3.0.2' // Ensure you have Jakarta Validation API dependency
implementation 'org.apache.httpcomponents.client5:httpclient5:5.4'

testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.3'
testImplementation 'commons-io:commons-io:2.11.0'
testImplementation 'io.vertx:vertx-web-client:4.5.10'
testImplementation 'io.vertx:vertx-junit5:4.5.10'
testImplementation 'org.mockito:mockito-core:5.7.0'
testImplementation 'org.mockito:mockito-junit-jupiter:5.7.0'
testImplementation 'org.apache.httpcomponents.client5:httpclient5:5.4'
testImplementation('com.github.codemonstur:embedded-redis:1.4.3') {
exclude group: 'org.slf4j', module: 'slf4j-simple'
}
Expand Down
12 changes: 8 additions & 4 deletions server/src/main/java/com/epam/aidial/core/server/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.epam.aidial.core.server.security.AccessTokenValidator;
import com.epam.aidial.core.server.security.ApiKeyStore;
import com.epam.aidial.core.server.security.EncryptionService;
import com.epam.aidial.core.server.service.ApplicationOperatorService;
import com.epam.aidial.core.server.service.ApplicationService;
import com.epam.aidial.core.server.service.HeartbeatService;
import com.epam.aidial.core.server.service.InvitationService;
Expand All @@ -18,6 +19,7 @@
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.VertxTimerService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.tracing.DialTracingFactory;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
Expand Down Expand Up @@ -122,8 +124,9 @@ void start() throws Exception {
InvitationService invitationService = new InvitationService(resourceService, encryptionService, settings("invitations"));
ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
ApplicationService applicationService = new ApplicationService(vertx, client, redis,
encryptionService, resourceService, lockService, generator, settings("applications"));
ApplicationOperatorService operatorService = new ApplicationOperatorService(client, settings("applications"));
ApplicationService applicationService = new ApplicationService(vertx, redis, encryptionService,
resourceService, lockService, operatorService, generator, settings("applications"));
ShareService shareService = new ShareService(resourceService, invitationService, encryptionService, applicationService, configStore);
RuleService ruleService = new RuleService(resourceService);
AccessService accessService = new AccessService(encryptionService, shareService, ruleService, settings("access"));
Expand All @@ -133,7 +136,8 @@ void start() throws Exception {
PublicationService publicationService = new PublicationService(encryptionService, resourceService, accessService,
ruleService, notificationService, applicationService, resourceOperationService, generator, clock);
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

CodeInterpreterService codeInterpreterService = new CodeInterpreterService(vertx, redis, resourceService,
accessService, encryptionService, operatorService, generator, settings("codeInterpreter"));

TokenStatsTracker tokenStatsTracker = new TokenStatsTracker(vertx, resourceService);

Expand All @@ -143,7 +147,7 @@ void start() throws Exception {
rateLimiter, upstreamRouteProvider, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
shareService, publicationService, accessService, lockService, resourceOperationService, ruleService,
notificationService, applicationService, heartbeatService, version());
notificationService, applicationService, codeInterpreterService, heartbeatService, version());

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/com/epam/aidial/core/server/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.epam.aidial.core.server.service.ResourceOperationService;
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
import com.epam.aidial.core.server.util.ProxyUtil;
Expand Down Expand Up @@ -88,6 +89,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final RuleService ruleService;
private final NotificationService notificationService;
private final ApplicationService applicationService;
private final CodeInterpreterService codeInterpreterService;
private final HeartbeatService heartbeatService;
private final String version;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,20 @@ public Future<?> respond(HttpStatus status, String contentType, Object object) {
}

public Future<?> respond(HttpStatus status, String body) {
return respond(status.getCode(), body);
}

public Future<?> respond(int status, String body) {
if (body == null) {
body = "";
}

if (status != HttpStatus.OK) {
if (status != HttpStatus.OK.getCode()) {
log.warn("Responding with error. Project: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status,
body.length() > LOG_MAX_ERROR_LENGTH ? body.substring(0, LOG_MAX_ERROR_LENGTH) : body);
}

response.setStatusCode(status.getCode()).end(body);
response.setStatusCode(status).end(body);
return Future.succeededFuture();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package com.epam.aidial.core.server.controller;

import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecute;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSessionId;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterError;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.InputStreamAdapter;
import com.epam.aidial.core.server.vertx.stream.InputStreamReader;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerFileUpload;
import io.vertx.core.http.HttpServerResponse;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.io.InputStream;

@Slf4j
class CodeInterpreterController {

private final ProxyContext context;
private final Vertx vertx;
private final CodeInterpreterService service;

public CodeInterpreterController(ProxyContext context) {
this.context = context;
this.vertx = context.getProxy().getVertx();
this.service = context.getProxy().getCodeInterpreterService();
}

Future<?> openSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.openSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> closeSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.closeSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> executeCode() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterExecute data = convertJson(body, CodeInterpreterExecute.class);
return vertx.executeBlocking(() -> service.executeCode(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> uploadFile() {
context.getRequest()
.setExpectMultipart(true)
.uploadHandler(upload -> {
// do not move inside execute blocking, otherwise you can miss the beginning of file
InputStream stream = new InputStreamAdapter(upload);
vertx.executeBlocking(() -> uploadFile(upload, stream), false)
.onSuccess(this::respondJson)
.onFailure(this::respondError);
});

return Future.succeededFuture();
}

@SneakyThrows
private CodeInterpreterFile uploadFile(HttpServerFileUpload upload, InputStream stream) {
String sessionId = context.getRequest().getParam("session_id");
String fileName = upload.filename();

if (sessionId == null) {
throw new IllegalArgumentException("Missing session_id query param");
}

if (fileName == null) {
throw new IllegalArgumentException("Missing filename in multipart upload");
}

return service.uploadFile(context, sessionId, fileName, stream);
}

Future<?> downloadFile() {
context.getRequest().body()
.compose(buffer -> vertx.executeBlocking(() -> downloadFile(buffer), false))
.onFailure(this::respondError);

return Future.succeededFuture();
}

private Void downloadFile(Buffer body) {
CodeInterpreterFile data = convertJson(body, CodeInterpreterFile.class);
HttpServerResponse response = context.getResponse();

return service.downloadFile(context, data.getSessionId(), data.getPath(), (stream, size) -> {
response.putHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(size));
return new InputStreamReader(vertx, stream)
.pipe()
.endOnFailure(false)
.to(response);
});
}

Future<?> listFiles() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.listFiles(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferInputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterInputFile data = convertJson(body, CodeInterpreterInputFile.class);
return vertx.executeBlocking(() -> service.transferInputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferOutputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterOutputFile data = convertJson(body, CodeInterpreterOutputFile.class);
return vertx.executeBlocking(() -> service.transferOutputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

private void respondJson(Object data) {
if (data instanceof CodeInterpreterSession session) {
session.setDeploymentId(null);
session.setDeploymentUrl(null);
session.setUsedAt(null);
}

context.respond(HttpStatus.OK, data);
}

private void respondError(Throwable error) {
HttpServerResponse response = context.getResponse();
if (response.headWritten()) {
response.reset();
} else if (error instanceof IllegalArgumentException) {
context.respond(HttpStatus.BAD_REQUEST, error.getMessage());
} else if (error instanceof PermissionDeniedException) {
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
context.respond(HttpStatus.NOT_FOUND, error.getMessage());
} else if (error instanceof HttpException e) {
context.respond(e.getStatus(), e.getMessage());
} else if (error instanceof CodeInterpreterError e) {
context.respond(e.getStatus(), e.getMessage());
} else {
log.error("Failed to handle code interpreter request", error);
context.respond(error, "Internal error");
}
}

private static <T> T convertJson(Buffer body, Class<T> clazz) {
try {
T result = ProxyUtil.convertToObject(body, clazz);

if (result == null) {
throw new IllegalArgumentException("No JSON body");
}

return result;
} catch (Exception e) {
throw new IllegalArgumentException("Not valid JSON body");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public class ControllerSelector {
private static final Pattern USER_INFO = Pattern.compile("^/v1/user/info$");

private static final Pattern APP_SCHEMAS = Pattern.compile("^/v1/application_type_schemas/(schemas|schema|meta_schema)?");
private static final Pattern CODE_INTERPRETER = Pattern.compile("^/v1/ops/code_interpreter/"
+ "(open_session|close_session|execute_code|"
+ "upload_file|download_file|list_files|"
+ "transfer_input_file|transfer_output_file)$");

static {
// GET routes
Expand Down Expand Up @@ -283,6 +287,22 @@ public class ControllerSelector {
default -> null;
};
});
post(CODE_INTERPRETER, (proxy, context, pathMatcher) -> {
String operation = pathMatcher.group(1);
CodeInterpreterController controller = new CodeInterpreterController(context);

return switch (operation) {
case "open_session" -> controller::openSession;
case "close_session" -> controller::closeSession;
case "execute_code" -> controller::executeCode;
case "upload_file" -> controller::uploadFile;
case "download_file" -> controller::downloadFile;
case "list_files" -> controller::listFiles;
case "transfer_input_file" -> controller::transferInputFile;
case "transfer_output_file" -> controller::transferOutputFile;
default -> null;
};
});
// DELETE routes
delete(PATTERN_FILES, (proxy, context, pathMatcher) -> {
ResourceController controller = new ResourceController(proxy, context, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.epam.aidial.core.server.data;

import lombok.Data;

import javax.annotation.Nullable;


@Data
public class AuthBucket {

/**
* The encrypted bucket location for the original JWT or API_KEY.
*/
String userBucket;
/**
* The bucket location for the original JWT or API_KEY.
*/
String userBucketLocation;

/**
* The encrypted bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucket;

/**
* The bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucketLocation;
}
Loading

0 comments on commit 29f3079

Please sign in to comment.