Skip to content

Commit

Permalink
fix: API key store performance degradation (#325)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <aliaksandr_stsiapanay@epam.com>
  • Loading branch information
astsiapanay and astsiapanay authored May 6, 2024
1 parent 5bbe814 commit fce5ee1
Showing 7 changed files with 47 additions and 83 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
@@ -119,7 +119,7 @@ void start() throws Exception {
AccessService accessService = new AccessService(encryptionService, shareService, publicationService, settings("access"));
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, lockService, vertx);
ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);

proxy = new Proxy(vertx, client, configStore, logStore,
Original file line number Diff line number Diff line change
@@ -125,15 +125,14 @@ private void handleRateLimitSuccess(String deploymentId) {
return;
}

setupProxyApiKeyData();
proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
// run setting up api key data in the worker thread
setupProxyApiKeyData();
handleRequestBody(body);
return null;
}).onFailure(this::handleError))
}, false).onFailure(this::handleError))
.onFailure(this::handleRequestBodyError);
}

@@ -144,7 +143,6 @@ private void setupProxyApiKeyData() {
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
}

private void handleRateLimitHit(RateLimitResult result) {
@@ -210,9 +208,9 @@ void handleRequestBody(Buffer requestBody) {

try {
ProxyUtil.collectAttachedFiles(tree, this::processAttachedFile);
// update api key data after processing attachments
// assign api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().updatePerRequestApiKeyData(destApiKeyData);
proxy.getApiKeyStore().assignPerRequestApiKey(destApiKeyData);
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
6 changes: 3 additions & 3 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ public Future<Void> increase(ProxyContext context) {

String tokensPath = getPathToTokens(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
return vertx.executeBlocking(() -> updateTokenLimit(resourceDescription, usage.getTotalTokens()));
return vertx.executeBlocking(() -> updateTokenLimit(resourceDescription, usage.getTotalTokens()), false);
} catch (Throwable e) {
return Future.failedFuture(e);
}
@@ -79,7 +79,7 @@ public Future<RateLimitResult> limit(ProxyContext context) {
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

return vertx.executeBlocking(() -> checkLimit(context, limit));
return vertx.executeBlocking(() -> checkLimit(context, limit), false);
} catch (Throwable e) {
return Future.failedFuture(e);
}
@@ -104,7 +104,7 @@ public Future<LimitStats> getLimitStats(String deploymentName, ProxyContext cont
context.getUserSub(), deploymentName);
return Future.succeededFuture();
}
return vertx.executeBlocking(() -> getLimitStats(context, limit, deploymentName));
return vertx.executeBlocking(() -> getLimitStats(context, limit, deploymentName), false);
} catch (Throwable e) {
return Future.failedFuture(e);
}
91 changes: 27 additions & 64 deletions src/main/java/com/epam/aidial/core/security/ApiKeyStore.java
Original file line number Diff line number Diff line change
@@ -3,13 +3,11 @@
import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
@@ -26,40 +24,39 @@
* </p>
*/
@Slf4j
@AllArgsConstructor
public class ApiKeyStore {

public static final String API_KEY_DATA_BUCKET = "api_key_data";
public static final String API_KEY_DATA_LOCATION = API_KEY_DATA_BUCKET + PATH_SEPARATOR;

private final ResourceService resourceService;

private final LockService lockService;

private final Vertx vertx;

/**
* Creates a store that keeps per request API key data in the given resource service.
*
* @param resourceService service used to put, get and delete API key data resources.
* @param vertx Vert.x instance used to run storage calls through {@code executeBlocking}.
*/
public ApiKeyStore(ResourceService resourceService, Vertx vertx) {
this.resourceService = resourceService;
this.vertx = vertx;
}

/**
* Project API keys are hosted in the secure storage.
*/
private final Map<String, ApiKeyData> keys = new HashMap<>();
private volatile Map<String, ApiKeyData> keys = new HashMap<>();

/**
* Assigns a newly generated per-request key to the {@link ApiKeyData}.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*/
public synchronized void assignPerRequestApiKey(ApiKeyData data) {
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
ResourceDescription resource = generateApiKey();
String apiKey = resource.getName();
data.setPerRequestKey(apiKey);
String json = ProxyUtil.convertToString(data);
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", apiKey));
}
return apiKey;
});
public void assignPerRequestApiKey(ApiKeyData data) {
String perRequestKey = generateKey();
ResourceDescription resource = toResource(perRequestKey);
// the key is recorded on the data before serialization, so the JSON is built from data that already holds it
data.setPerRequestKey(perRequestKey);
String json = ProxyUtil.convertToString(data);
// a null result from putResource is reported as "key already exists"; no retry with another key is attempted here
// NOTE(review): the two boolean flags are presumably "overwrite" and "lock" switches - confirm against ResourceService
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", perRequestKey));
}
}

/**
@@ -68,13 +65,13 @@ public synchronized void assignPerRequestApiKey(ApiKeyData data) {
* @param key API key could be either project or per request key.
* @return the future of data associated with the given key.
*/
public synchronized Future<ApiKeyData> getApiKeyData(String key) {
public Future<ApiKeyData> getApiKeyData(String key) {
ApiKeyData apiKeyData = keys.get(key);
if (apiKeyData != null) {
return Future.succeededFuture(apiKeyData);
}
ResourceDescription resource = toResource(key);
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class));
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class), false);
}

/**
@@ -88,7 +85,7 @@ public Future<Boolean> invalidatePerRequestApiKey(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey != null) {
ResourceDescription resource = toResource(apiKey);
return vertx.executeBlocking(() -> resourceService.deleteResource(resource));
return vertx.executeBlocking(() -> resourceService.deleteResource(resource), false);
}
return Future.succeededFuture(true);
}
@@ -101,51 +98,17 @@ public Future<Boolean> invalidatePerRequestApiKey(ApiKeyData apiKeyData) {
*
* @param projectKeys new projects to be added to the store.
*/
public synchronized void addProjectKeys(Map<String, Key> projectKeys) {
// previously loaded project keys are dropped before the new set is built
keys.clear();
// the rebuild runs under the bucket lock taken for the API key data location
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
ResourceDescription resource = toResource(apiKey);
// if a resource is already stored under this key, the project key is replaced with a generated one
if (resourceService.hasResource(resource)) {
resource = generateApiKey();
apiKey = resource.getName();
}
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(apiKey, apiKeyData);
}
// the lock callback has no result to hand back
return null;
});
}

/**
* Updates data associated with per request key.
* If the data carries no per request key (presumably the case for project keys), the operation has no effect.
*
* @param apiKeyData per request key data.
*/
public void updatePerRequestApiKeyData(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
// nothing to update when no per request key is set
if (apiKey == null) {
return;
}
String json = ProxyUtil.convertToString(apiKeyData);
ResourceDescription resource = toResource(apiKey);
// NOTE(review): the third argument is true here but false in assignPerRequestApiKey - presumably an overwrite flag; confirm
resourceService.putResource(resource, json, true, false);
}

private ResourceDescription generateApiKey() {
String apiKey = generateKey();
ResourceDescription resource = toResource(apiKey);
while (resourceService.hasResource(resource) || keys.containsKey(apiKey)) {
log.warn("duplicate API key is found. Trying to generate a new one");
apiKey = generateKey();
resource = toResource(apiKey);
public void addProjectKeys(Map<String, Key> projectKeys) {
Map<String, ApiKeyData> apiKeyDataMap = new HashMap<>();
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
apiKeyDataMap.put(apiKey, apiKeyData);
}
return resource;
keys = apiKeyDataMap;
}

private static ResourceDescription toResource(String apiKey) {
Original file line number Diff line number Diff line change
@@ -173,6 +173,7 @@ public void testHandler_Ok() {
when(request.headers()).thenReturn(headers);
when(context.getDeployment()).thenReturn(application);
when(proxy.getTokenStatsTracker()).thenReturn(tokenStatsTracker);
when(context.getApiKeyData()).thenReturn(new ApiKeyData());

controller.handle("app1", "chat/completions");

13 changes: 7 additions & 6 deletions src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
@@ -192,7 +193,7 @@ public void testLimit_ApiKeySuccess_KeyNotFound() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -224,7 +225,7 @@ public void testLimit_ApiKeySuccess_KeyExist() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -277,7 +278,7 @@ public void testGetLimitStats_ApiKey() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -351,7 +352,7 @@ public void testLimit_User_LimitFound() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -393,7 +394,7 @@ public void testLimit_User_DefaultLimit() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -449,7 +450,7 @@ public void testLimit_User_RequestLimit() {
model.setName("model");
proxyContext.setDeployment(model);

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
@@ -95,7 +96,7 @@ public void beforeEach() {
}
""";
ResourceService resourceService = new ResourceService(vertx, redissonClient, blobStorage, lockService, new JsonObject(resourceConfig), null);
store = new ApiKeyStore(resourceService, lockService, vertx);
store = new ApiKeyStore(resourceService, vertx);
}

@Test
@@ -107,7 +108,7 @@ public void testAssignApiKey() {

@Test
public void testAddProjectKeys() {
when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
@@ -147,7 +148,7 @@ public void testGetApiKeyData() {

assertNotNull(apiKeyData.getPerRequestKey());

when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

0 comments on commit fce5ee1

Please sign in to comment.