Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: code interpreter #633

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Priority order:
| encryption.secret | - | No |Secret is used for AES encryption of a prefix to the bucket blob storage. The value should be random generated string.
| encryption.key | - | No |Key is used for AES encryption of a prefix to the bucket blob storage. The value should be random generated string.
| resources.maxSize | 67108864 | No |Max allowed size in bytes for a resource.
| resources.maxSizeToCache | 1048576 | No |Max size in bytes for a resource to cache in Redis.
| resources.maxSizeToCache | 1048576 | No |Max size in bytes for a resource to cache in Redis.
| resources.syncPeriod | 60000 | No |Period in milliseconds, how frequently check for resources to sync.
| resources.syncDelay | 120000 | No |Delay in milliseconds for a resource to be written back in object storage after last modification.
| resources.syncBatch | 4096 | No |How many resources to sync in one go.
Expand All @@ -97,7 +97,10 @@ Priority order:
| applications.includeCustomApps | false | No |The flag indicates whether custom applications should be included into openai listing
| applications.controllerEndpoint | - | No |The endpoint to Application Controller Web Service that manages deployments for applications with functions
| applications.controllerTimeout | 240000 | No |The timeout of operations to Application Controller Web Service
| applications.checkPeriod | 300000 | No |The interval at which to check the pending operations for applications with functions
| codeInterpreter.sessionImage | - | No |The code interpreter session image to use
| codeInterpreter.sessionTtl | 600000 | No |The session time to leave after the last API call
| codeInterpreter.checkPeriod | 10000 | No |The interval at which to check active sessions for expiration
| codeInterpreter.checkSize | 256 | No |The maximum number of active sessions to check in single check

### Storage requirements

Expand Down
2 changes: 1 addition & 1 deletion server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ dependencies {
implementation 'org.hibernate.validator:hibernate-validator:8.0.0.Final'
implementation 'org.glassfish:jakarta.el:4.0.2'
implementation 'jakarta.validation:jakarta.validation-api:3.0.2' // Ensure you have Jakarta Validation API dependency
implementation 'org.apache.httpcomponents.client5:httpclient5:5.4'
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved

testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.3'
testImplementation 'commons-io:commons-io:2.11.0'
testImplementation 'io.vertx:vertx-web-client:4.5.10'
testImplementation 'io.vertx:vertx-junit5:4.5.10'
testImplementation 'org.mockito:mockito-core:5.7.0'
testImplementation 'org.mockito:mockito-junit-jupiter:5.7.0'
testImplementation 'org.apache.httpcomponents.client5:httpclient5:5.4'
testImplementation('com.github.codemonstur:embedded-redis:1.4.3') {
exclude group: 'org.slf4j', module: 'slf4j-simple'
}
Expand Down
12 changes: 8 additions & 4 deletions server/src/main/java/com/epam/aidial/core/server/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.epam.aidial.core.server.security.AccessTokenValidator;
import com.epam.aidial.core.server.security.ApiKeyStore;
import com.epam.aidial.core.server.security.EncryptionService;
import com.epam.aidial.core.server.service.ApplicationOperatorService;
import com.epam.aidial.core.server.service.ApplicationService;
import com.epam.aidial.core.server.service.HeartbeatService;
import com.epam.aidial.core.server.service.InvitationService;
Expand All @@ -18,6 +19,7 @@
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.VertxTimerService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.tracing.DialTracingFactory;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
Expand Down Expand Up @@ -122,8 +124,9 @@ void start() throws Exception {
InvitationService invitationService = new InvitationService(resourceService, encryptionService, settings("invitations"));
ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
ApplicationService applicationService = new ApplicationService(vertx, client, redis,
encryptionService, resourceService, lockService, generator, settings("applications"));
ApplicationOperatorService operatorService = new ApplicationOperatorService(client, settings("applications"));
ApplicationService applicationService = new ApplicationService(vertx, redis, encryptionService,
resourceService, lockService, operatorService, generator, settings("applications"));
ShareService shareService = new ShareService(resourceService, invitationService, encryptionService, applicationService, configStore);
RuleService ruleService = new RuleService(resourceService);
AccessService accessService = new AccessService(encryptionService, shareService, ruleService, settings("access"));
Expand All @@ -133,7 +136,8 @@ void start() throws Exception {
PublicationService publicationService = new PublicationService(encryptionService, resourceService, accessService,
ruleService, notificationService, applicationService, resourceOperationService, generator, clock);
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

CodeInterpreterService codeInterpreterService = new CodeInterpreterService(vertx, redis, resourceService,
accessService, encryptionService, operatorService, generator, settings("codeInterpreter"));

TokenStatsTracker tokenStatsTracker = new TokenStatsTracker(vertx, resourceService);

Expand All @@ -143,7 +147,7 @@ void start() throws Exception {
rateLimiter, upstreamRouteProvider, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
shareService, publicationService, accessService, lockService, resourceOperationService, ruleService,
notificationService, applicationService, heartbeatService, version());
notificationService, applicationService, codeInterpreterService, heartbeatService, version());

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/com/epam/aidial/core/server/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.epam.aidial.core.server.service.ResourceOperationService;
import com.epam.aidial.core.server.service.RuleService;
import com.epam.aidial.core.server.service.ShareService;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
import com.epam.aidial.core.server.util.ProxyUtil;
Expand Down Expand Up @@ -88,6 +89,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final RuleService ruleService;
private final NotificationService notificationService;
private final ApplicationService applicationService;
private final CodeInterpreterService codeInterpreterService;
private final HeartbeatService heartbeatService;
private final String version;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package com.epam.aidial.core.server.controller;

import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterExecuteRequest;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterInputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterOutputFile;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSession;
import com.epam.aidial.core.server.data.codeinterpreter.CodeInterpreterSessionId;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.service.codeinterpreter.CodeInterpreterService;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.InputStreamAdapter;
import com.epam.aidial.core.server.vertx.stream.InputStreamReader;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerFileUpload;
import io.vertx.core.http.HttpServerResponse;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.io.InputStream;

@Slf4j
class CodeInterpreterController {

private final ProxyContext context;
private final Vertx vertx;
private final CodeInterpreterService service;

public CodeInterpreterController(ProxyContext context) {
this.context = context;
this.vertx = context.getProxy().getVertx();
this.service = context.getProxy().getCodeInterpreterService();
}

Future<?> openSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.openSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> closeSession() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.closeSession(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> executeCode() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterExecuteRequest data = convertJson(body, CodeInterpreterExecuteRequest.class);
return vertx.executeBlocking(() -> service.executeCode(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> uploadFile() {
context.getRequest()
.setExpectMultipart(true)
.uploadHandler(upload -> {
// do not move inside execute blocking, otherwise you can miss the beginning of file
InputStreamAdapter stream = new InputStreamAdapter(upload);
vertx.executeBlocking(() -> uploadFile(upload, stream), false)
.onSuccess(this::respondJson)
.onComplete(e -> stream.close())
.onFailure(this::respondError);
});

return Future.succeededFuture();
}

@SneakyThrows
private CodeInterpreterFile uploadFile(HttpServerFileUpload upload, InputStream stream) {
String sessionId = context.getRequest().getParam("session_id");
String fileName = upload.filename();

if (sessionId == null) {
throw new IllegalArgumentException("Missing session_id query param");
}

if (fileName == null) {
throw new IllegalArgumentException("Missing filename in multipart upload");
}

return service.uploadFile(context, sessionId, fileName, stream);
}

Future<?> downloadFile() {
context.getRequest().body()
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
.compose(buffer -> vertx.executeBlocking(() -> downloadFile(buffer), false))
.onFailure(this::respondError);

return Future.succeededFuture();
}

private Void downloadFile(Buffer body) {
CodeInterpreterFile data = convertJson(body, CodeInterpreterFile.class);
HttpServerResponse response = context.getResponse();

return service.downloadFile(context, data.getSessionId(), data.getPath(), (stream, size) -> {
response.putHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(size));
return new InputStreamReader(vertx, stream)
.pipe()
.endOnFailure(false)
.to(response);
});
}

Future<?> listFiles() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterSessionId data = convertJson(body, CodeInterpreterSessionId.class);
return vertx.executeBlocking(() -> service.listFiles(context, data.getSessionId()), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferInputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterInputFile data = convertJson(body, CodeInterpreterInputFile.class);
return vertx.executeBlocking(() -> service.transferInputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

Future<?> transferOutputFile() {
context.getRequest()
.body()
.compose(body -> {
CodeInterpreterOutputFile data = convertJson(body, CodeInterpreterOutputFile.class);
return vertx.executeBlocking(() -> service.transferOutputFile(context, data), false);
})
.onSuccess(this::respondJson)
.onFailure(this::respondError);

return Future.succeededFuture();
}

private void respondJson(Object data) {
if (data instanceof CodeInterpreterSession session) {
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
session.setDeploymentId(null);
session.setDeploymentUrl(null);
session.setUsedAt(null);
}

context.respond(HttpStatus.OK, data);
}

private void respondError(Throwable error) {
HttpServerResponse response = context.getResponse();
if (response.headWritten()) {
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
// download request can partially fail, when some data already is sent, it is too late to send response
// so the only option is to disconnect client
response.reset();
} else if (error instanceof IllegalArgumentException) {
context.respond(HttpStatus.BAD_REQUEST, error.getMessage());
} else if (error instanceof PermissionDeniedException) {
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
context.respond(HttpStatus.NOT_FOUND, error.getMessage());
} else if (error instanceof HttpException e) {
context.respond(e.getStatus(), e.getMessage());
} else {
log.error("Failed to handle code interpreter request", error);
context.respond(error, "Internal error");
}
}

private static <T> T convertJson(Buffer body, Class<T> clazz) {
try {
T result = ProxyUtil.convertToObject(body, clazz);

if (result == null) {
throw new IllegalArgumentException("No JSON body");
}

return result;
} catch (Exception e) {
throw new IllegalArgumentException("Not valid JSON body");
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public class ControllerSelector {
private static final Pattern USER_INFO = Pattern.compile("^/v1/user/info$");

private static final Pattern APP_SCHEMAS = Pattern.compile("^/v1/application_type_schemas/(schemas|schema|meta_schema)?");
private static final Pattern CODE_INTERPRETER = Pattern.compile("^/v1/ops/code_interpreter/"
+ "(open_session|close_session|execute_code|"
+ "upload_file|download_file|list_files|"
+ "transfer_input_file|transfer_output_file)$");

static {
// GET routes
Expand Down Expand Up @@ -283,6 +287,22 @@ public class ControllerSelector {
default -> null;
};
});
post(CODE_INTERPRETER, (proxy, context, pathMatcher) -> {
String operation = pathMatcher.group(1);
CodeInterpreterController controller = new CodeInterpreterController(context);

return switch (operation) {
case "open_session" -> controller::openSession;
case "close_session" -> controller::closeSession;
case "execute_code" -> controller::executeCode;
case "upload_file" -> controller::uploadFile;
case "download_file" -> controller::downloadFile;
case "list_files" -> controller::listFiles;
case "transfer_input_file" -> controller::transferInputFile;
case "transfer_output_file" -> controller::transferOutputFile;
default -> null;
};
});
// DELETE routes
delete(PATTERN_FILES, (proxy, context, pathMatcher) -> {
ResourceController controller = new ResourceController(proxy, context, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.epam.aidial.core.server.data;

import lombok.Data;

import javax.annotation.Nullable;


@Data
public class AuthBucket {

/**
* The encrypted bucket location for the original JWT or API_KEY.
*/
String userBucket;
/**
* The bucket location for the original JWT or API_KEY.
*/
String userBucketLocation;

/**
* The encrypted bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucket;

/**
* The bucket location for the application from PER_REQUEST_KEY if present.
*/
@Nullable
String appBucketLocation;
}
Loading
Loading