Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf test rollback #322

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ void start() throws Exception {
vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();

Expand All @@ -119,9 +121,6 @@ void start() throws Exception {
AccessService accessService = new AccessService(encryptionService, shareService, publicationService, settings("access"));
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, lockService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);

proxy = new Proxy(vertx, client, configStore, logStore,
rateLimiter, upstreamBalancer, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
Expand Down
46 changes: 18 additions & 28 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,53 +147,43 @@ private void handleRequest(HttpServerRequest request) {
String traceId = spanContext.getTraceId();
String spanId = spanContext.getSpanId();
log.debug("Authorization header: {}", authorization);
Future<AuthorizationResult> authorizationResultFuture;

request.pause();
ApiKeyData apiKeyData;
if (apiKey == null && authorization == null) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
} else if (apiKey != null) {
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
.compose(apiKeyData -> {
if (apiKeyData == null) {
String errorMessage = "Unknown api key";
respond(request, HttpStatus.UNAUTHORIZED, errorMessage);
return Future.failedFuture(errorMessage);
}
return Future.succeededFuture(new AuthorizationResult(apiKeyData, null));
});
apiKeyData = apiKeyStore.getApiKeyData(apiKey);
// Special case handling. OpenAI client sends both API key and Auth headers even if a caller sets just API Key only
// Auth header is set to the same value as API Key header
// ignore auth header in this case
authorization = null;
if (apiKeyData == null) {
respond(request, HttpStatus.UNAUTHORIZED, "Unknown api key");
return;
}
} else {
authorizationResultFuture = tokenValidator.extractClaims(authorization)
.onFailure(error -> onExtractClaimsFailure(error, request))
.map(extractedClaims -> new AuthorizationResult(new ApiKeyData(), extractedClaims));
apiKeyData = new ApiKeyData();
}

authorizationResultFuture.compose(result -> processAuthorizationResult(result.extractedClaims, config, request, result.apiKeyData, traceId, spanId))
.onComplete(ignore -> request.resume());
}

private record AuthorizationResult(ApiKeyData apiKeyData, ExtractedClaims extractedClaims) {
request.pause();
Future<ExtractedClaims> extractedClaims = tokenValidator.extractClaims(authorization);

extractedClaims.onFailure(error -> onExtractClaimsFailure(error, request))
.compose(claims -> onExtractClaimsSuccess(claims, config, request, apiKeyData, traceId, spanId))
.onComplete(ignore -> request.resume());
}

private void onExtractClaimsFailure(Throwable error, HttpServerRequest request) {
log.error("Can't extract claims from authorization header", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

private void onGettingApiKeyDataFailure(Throwable error, HttpServerRequest request) {
log.error("Can't find data associated with API key", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

@SneakyThrows
private Future<?> processAuthorizationResult(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
private Future<?> onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
Future<?> future;
try {
ProxyContext context = new ProxyContext(config, request, apiKeyData, extractedClaims, traceId, spanId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,26 @@ private void handleRateLimitSuccess(String deploymentId) {
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
long start = System.currentTimeMillis();
proxy.getApiKeyStore().assignApiKey(proxyApiKeyData);
log.info("Complete assigning per request api key for {} ms. Trace: {}. Span: {}. Key: {}. Deployment: {}",
System.currentTimeMillis() - start,
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName());

proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
// run setting up api key data in the worker thread
setupProxyApiKeyData();
handleRequestBody(body);
return null;
}).onFailure(this::handleError))
}))
.onFailure(this::handleRequestBodyError);
}

/**
* The method uses blocking calls and should not be used in the event loop thread.
*/
private void setupProxyApiKeyData() {
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
}

private void handleRateLimitHit(RateLimitResult result) {
// Returning an error similar to the Azure format.
ErrorData rateLimitError = new ErrorData();
Expand Down Expand Up @@ -210,9 +208,6 @@ void handleRequestBody(Buffer requestBody) {

try {
ProxyUtil.collectAttachedFiles(tree, this::processAttachedFile);
// update api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().updatePerRequestApiKeyData(destApiKeyData);
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
Expand Down Expand Up @@ -629,14 +624,8 @@ private Future<Void> respond(HttpStatus status, Object result) {

private void finalizeRequest() {
proxy.getTokenStatsTracker().endSpan(context);
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
if (context.getProxyApiKeyData() != null) {
proxy.getApiKeyStore().invalidateApiKey(context.getProxyApiKeyData());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
public enum ResourceType {
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts"), LIMIT("limits"),
SHARED_WITH_ME("shared_with_me"), SHARED_BY_ME("shared_by_me"), INVITATION("invitations"),
PUBLICATION("publications"), RULES("rules"), API_KEY_DATA("api_key_data");
PUBLICATION("publications"), RULES("rules");

private final String group;

Expand Down
144 changes: 23 additions & 121 deletions src/main/java/com/epam/aidial/core/security/ApiKeyStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,155 +2,57 @@

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;
import static com.epam.aidial.core.storage.BlobStorageUtil.PATH_SEPARATOR;

/**
* The store keeps per request and project API key data.
* <p>
* Per request key is assigned during the request and terminated in the end of the request.
* Project keys are hosted by external secure storage and might be periodically updated by {@link com.epam.aidial.core.config.FileConfigStore}.
* </p>
*/
@Slf4j
@AllArgsConstructor
public class ApiKeyStore {

public static final String API_KEY_DATA_BUCKET = "api_key_data";
public static final String API_KEY_DATA_LOCATION = API_KEY_DATA_BUCKET + PATH_SEPARATOR;

private final ResourceService resourceService;

private final LockService lockService;

private final Vertx vertx;

/**
* Project API keys are hosted in the secure storage.
*/
private final Map<String, ApiKeyData> keys = new HashMap<>();

/**
* Assigns a new generated per request key to the {@link ApiKeyData}.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*/
public synchronized void assignPerRequestApiKey(ApiKeyData data) {
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
ResourceDescription resource = generateApiKey();
String apiKey = resource.getName();
data.setPerRequestKey(apiKey);
String json = ProxyUtil.convertToString(data);
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", apiKey));
}
return apiKey;
});
public synchronized void assignApiKey(ApiKeyData data) {
String apiKey = generateApiKey();
keys.put(apiKey, data);
data.setPerRequestKey(apiKey);
}

/**
* Returns API key data for the given key.
*
* @param key API key could be either project or per request key.
* @return the future of data associated with the given key.
*/
public synchronized Future<ApiKeyData> getApiKeyData(String key) {
ApiKeyData apiKeyData = keys.get(key);
if (apiKeyData != null) {
return Future.succeededFuture(apiKeyData);
}
ResourceDescription resource = toResource(key);
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class));
public synchronized ApiKeyData getApiKeyData(String key) {
return keys.get(key);
}

/**
* Invalidates per request API key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData associated with the key to be invalidated.
* @return the future of the invalidation result: <code>true</code> means the key is successfully invalidated.
*/
public Future<Boolean> invalidatePerRequestApiKey(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey != null) {
ResourceDescription resource = toResource(apiKey);
return vertx.executeBlocking(() -> resourceService.deleteResource(resource));
public synchronized void invalidateApiKey(ApiKeyData apiKeyData) {
if (apiKeyData.getPerRequestKey() != null) {
keys.remove(apiKeyData.getPerRequestKey());
}
return Future.succeededFuture(true);
}

/**
* Adds new project keys from the secure storage and removes previous project keys if any.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*
* @param projectKeys new projects to be added to the store.
*/
public synchronized void addProjectKeys(Map<String, Key> projectKeys) {
keys.clear();
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
ResourceDescription resource = toResource(apiKey);
if (resourceService.hasResource(resource)) {
resource = generateApiKey();
apiKey = resource.getName();
}
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(apiKey, apiKeyData);
keys.values().removeIf(apiKeyData -> apiKeyData.getPerRequestKey() == null);
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String key = entry.getKey();
Key value = entry.getValue();
if (keys.containsKey(key)) {
key = generateApiKey();
}
return null;
});
}

/**
* Updates data associated with per request key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData per request key data.
*/
public void updatePerRequestApiKeyData(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey == null) {
return;
value.setKey(key);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(key, apiKeyData);
}
String json = ProxyUtil.convertToString(apiKeyData);
ResourceDescription resource = toResource(apiKey);
resourceService.putResource(resource, json, true, false);
}

private ResourceDescription generateApiKey() {
private String generateApiKey() {
String apiKey = generateKey();
ResourceDescription resource = toResource(apiKey);
while (resourceService.hasResource(resource) || keys.containsKey(apiKey)) {
while (keys.containsKey(apiKey)) {
log.warn("duplicate API key is found. Trying to generate a new one");
apiKey = generateKey();
resource = toResource(apiKey);
}
return resource;
}

private static ResourceDescription toResource(String apiKey) {
return ResourceDescription.fromDecoded(
ResourceType.API_KEY_DATA, API_KEY_DATA_BUCKET, API_KEY_DATA_LOCATION, apiKey);
return apiKey;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ private ResourceFolderMetadata getFolderMetadata(ResourceDescription descriptor,
return new ResourceFolderMetadata(descriptor, resources).setNextToken(set.getNextMarker());
}

@Nullable
public ResourceItemMetadata getResourceMetadata(ResourceDescription descriptor) {
if (descriptor.isFolder()) {
throw new IllegalArgumentException("Resource folder: " + descriptor.getUrl());
Expand Down
Loading
Loading