Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement per request API keys through stateful layer #183 #312

Merged
merged 3 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ void start() throws Exception {
vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();

Expand All @@ -121,6 +119,9 @@ void start() throws Exception {
AccessService accessService = new AccessService(encryptionService, shareService, publicationService, settings("access"));
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, lockService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);

proxy = new Proxy(vertx, client, configStore, logStore,
rateLimiter, upstreamBalancer, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
Expand Down
46 changes: 28 additions & 18 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,43 +147,53 @@ private void handleRequest(HttpServerRequest request) {
String traceId = spanContext.getTraceId();
String spanId = spanContext.getSpanId();
log.debug("Authorization header: {}", authorization);
ApiKeyData apiKeyData;
Future<AuthorizationResult> authorizationResultFuture;

request.pause();
if (apiKey == null && authorization == null) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
} else if (apiKey != null) {
apiKeyData = apiKeyStore.getApiKeyData(apiKey);
// Special case handling. OpenAI client sends both API key and Auth headers even if a caller sets just API Key only
// Auth header is set to the same value as API Key header
// ignore auth header in this case
authorization = null;
if (apiKeyData == null) {
respond(request, HttpStatus.UNAUTHORIZED, "Unknown api key");
return;
}
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
.compose(apiKeyData -> {
if (apiKeyData == null) {
String errorMessage = "Unknown api key";
respond(request, HttpStatus.UNAUTHORIZED, errorMessage);
return Future.failedFuture(errorMessage);
}
return Future.succeededFuture(new AuthorizationResult(apiKeyData, null));
});
} else {
apiKeyData = new ApiKeyData();
authorizationResultFuture = tokenValidator.extractClaims(authorization)
.onFailure(error -> onExtractClaimsFailure(error, request))
.map(extractedClaims -> new AuthorizationResult(new ApiKeyData(), extractedClaims));
}

request.pause();
Future<ExtractedClaims> extractedClaims = tokenValidator.extractClaims(authorization);

extractedClaims.onFailure(error -> onExtractClaimsFailure(error, request))
.compose(claims -> onExtractClaimsSuccess(claims, config, request, apiKeyData, traceId, spanId))
authorizationResultFuture.compose(result -> processAuthorizationResult(result.extractedClaims, config, request, result.apiKeyData, traceId, spanId))
.onComplete(ignore -> request.resume());
}

private record AuthorizationResult(ApiKeyData apiKeyData, ExtractedClaims extractedClaims) {

}

private void onExtractClaimsFailure(Throwable error, HttpServerRequest request) {
log.error("Can't extract claims from authorization header", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

private void onGettingApiKeyDataFailure(Throwable error, HttpServerRequest request) {
log.error("Can't find data associated with API key", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

@SneakyThrows
private Future<?> onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
private Future<?> processAuthorizationResult(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
Future<?> future;
try {
ProxyContext context = new ProxyContext(config, request, apiKeyData, extractedClaims, traceId, spanId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,28 @@ private void handleRateLimitSuccess(String deploymentId) {
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignApiKey(proxyApiKeyData);

proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
// run setting up api key data in the worker thread
setupProxyApiKeyData();
handleRequestBody(body);
return null;
}))
}).onFailure(this::handleError))
.onFailure(this::handleRequestBodyError);
}

/**
* The method uses blocking calls and should not be used in the event loop thread.
*/
private void setupProxyApiKeyData() {
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
}

private void handleRateLimitHit(RateLimitResult result) {
// Returning an error similar to the Azure format.
ErrorData rateLimitError = new ErrorData();
Expand Down Expand Up @@ -203,6 +210,9 @@ void handleRequestBody(Buffer requestBody) {

try {
ProxyUtil.collectAttachedFiles(tree, this::processAttachedFile);
// update api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().updatePerRequestApiKeyData(destApiKeyData);
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
Expand Down Expand Up @@ -619,8 +629,14 @@ private Future<Void> respond(HttpStatus status, Object result) {

private void finalizeRequest() {
proxy.getTokenStatsTracker().endSpan(context);
if (context.getProxyApiKeyData() != null) {
proxy.getApiKeyStore().invalidateApiKey(context.getProxyApiKeyData());
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
public enum ResourceType {
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts"), LIMIT("limits"),
SHARED_WITH_ME("shared_with_me"), SHARED_BY_ME("shared_by_me"), INVITATION("invitations"),
PUBLICATION("publications"), RULES("rules");
PUBLICATION("publications"), RULES("rules"), API_KEY_DATA("api_key_data");

private final String group;

Expand Down
144 changes: 121 additions & 23 deletions src/main/java/com/epam/aidial/core/security/ApiKeyStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,155 @@

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;
import static com.epam.aidial.core.storage.BlobStorageUtil.PATH_SEPARATOR;

/**
* The store keeps per request and project API key data.
* <p>
* Per request key is assigned during the request and terminated in the end of the request.
* Project keys are hosted by external secure storage and might be periodically updated by {@link com.epam.aidial.core.config.FileConfigStore}.
* </p>
*/
@Slf4j
@AllArgsConstructor
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
public class ApiKeyStore {

public static final String API_KEY_DATA_BUCKET = "api_key_data";
public static final String API_KEY_DATA_LOCATION = API_KEY_DATA_BUCKET + PATH_SEPARATOR;

private final ResourceService resourceService;

private final LockService lockService;

private final Vertx vertx;

/**
* Project API keys are hosted in the secure storage.
*/
private final Map<String, ApiKeyData> keys = new HashMap<>();

public synchronized void assignApiKey(ApiKeyData data) {
String apiKey = generateApiKey();
keys.put(apiKey, data);
data.setPerRequestKey(apiKey);
/**
* Assigns a new generated per request key to the {@link ApiKeyData}.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*/
public synchronized void assignPerRequestApiKey(ApiKeyData data) {
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
ResourceDescription resource = generateApiKey();
String apiKey = resource.getName();
data.setPerRequestKey(apiKey);
String json = ProxyUtil.convertToString(data);
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", apiKey));
}
return apiKey;
});
}

public synchronized ApiKeyData getApiKeyData(String key) {
return keys.get(key);
/**
* Returns API key data for the given key.
*
* @param key API key could be either project or per request key.
* @return the future of data associated with the given key.
*/
public synchronized Future<ApiKeyData> getApiKeyData(String key) {
ApiKeyData apiKeyData = keys.get(key);
if (apiKeyData != null) {
return Future.succeededFuture(apiKeyData);
}
ResourceDescription resource = toResource(key);
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class));
}

public synchronized void invalidateApiKey(ApiKeyData apiKeyData) {
if (apiKeyData.getPerRequestKey() != null) {
keys.remove(apiKeyData.getPerRequestKey());
/**
* Invalidates per request API key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData associated with the key to be invalidated.
* @return the future of the invalidation result: <code>true</code> means the key is successfully invalidated.
*/
public Future<Boolean> invalidatePerRequestApiKey(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey != null) {
ResourceDescription resource = toResource(apiKey);
return vertx.executeBlocking(() -> resourceService.deleteResource(resource));
}
return Future.succeededFuture(true);
}

/**
* Adds new project keys from the secure storage and removes previous project keys if any.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*
* @param projectKeys new projects to be added to the store.
*/
public synchronized void addProjectKeys(Map<String, Key> projectKeys) {
keys.values().removeIf(apiKeyData -> apiKeyData.getPerRequestKey() == null);
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String key = entry.getKey();
Key value = entry.getValue();
if (keys.containsKey(key)) {
key = generateApiKey();
keys.clear();
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
ResourceDescription resource = toResource(apiKey);
if (resourceService.hasResource(resource)) {
resource = generateApiKey();
apiKey = resource.getName();
}
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(apiKey, apiKeyData);
}
value.setKey(key);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(key, apiKeyData);
return null;
});
}

/**
* Updates data associated with per request key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData per request key data.
*/
public void updatePerRequestApiKeyData(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey == null) {
return;
}
String json = ProxyUtil.convertToString(apiKeyData);
ResourceDescription resource = toResource(apiKey);
resourceService.putResource(resource, json, true, false);
}

private String generateApiKey() {
private ResourceDescription generateApiKey() {
String apiKey = generateKey();
while (keys.containsKey(apiKey)) {
ResourceDescription resource = toResource(apiKey);
while (resourceService.hasResource(resource) || keys.containsKey(apiKey)) {
log.warn("duplicate API key is found. Trying to generate a new one");
apiKey = generateKey();
resource = toResource(apiKey);
}
return apiKey;
return resource;
}

private static ResourceDescription toResource(String apiKey) {
return ResourceDescription.fromDecoded(
ResourceType.API_KEY_DATA, API_KEY_DATA_BUCKET, API_KEY_DATA_LOCATION, apiKey);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private ResourceFolderMetadata getFolderMetadata(ResourceDescription descriptor,
return new ResourceFolderMetadata(descriptor, resources).setNextToken(set.getNextMarker());
}

@Nullable
public ResourceItemMetadata getResourceMetadata(ResourceDescription descriptor) {
if (descriptor.isFolder()) {
throw new IllegalArgumentException("Resource folder: " + descriptor.getUrl());
Expand Down
Loading
Loading