Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement per request API keys through stateful layer #183 #312

Merged
merged 3 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ void start() throws Exception {
vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();

Expand All @@ -121,6 +119,9 @@ void start() throws Exception {
AccessService accessService = new AccessService(encryptionService, shareService, publicationService, settings("access"));
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, lockService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);

proxy = new Proxy(vertx, client, configStore, logStore,
rateLimiter, upstreamBalancer, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
Expand Down
46 changes: 28 additions & 18 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,43 +147,53 @@ private void handleRequest(HttpServerRequest request) {
String traceId = spanContext.getTraceId();
String spanId = spanContext.getSpanId();
log.debug("Authorization header: {}", authorization);
ApiKeyData apiKeyData;
Future<AuthorizationResult> authorizationResultFuture;

request.pause();
if (apiKey == null && authorization == null) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
} else if (apiKey != null) {
apiKeyData = apiKeyStore.getApiKeyData(apiKey);
// Special case handling. OpenAI client sends both API key and Auth headers even if a caller sets just API Key only
// Auth header is set to the same value as API Key header
// ignore auth header in this case
authorization = null;
if (apiKeyData == null) {
respond(request, HttpStatus.UNAUTHORIZED, "Unknown api key");
return;
}
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
.compose(apiKeyData -> {
if (apiKeyData == null) {
String errorMessage = "Unknown api key";
respond(request, HttpStatus.UNAUTHORIZED, errorMessage);
return Future.failedFuture(errorMessage);
}
return Future.succeededFuture(new AuthorizationResult(apiKeyData, null));
});
} else {
apiKeyData = new ApiKeyData();
authorizationResultFuture = tokenValidator.extractClaims(authorization)
.onFailure(error -> onExtractClaimsFailure(error, request))
.map(extractedClaims -> new AuthorizationResult(new ApiKeyData(), extractedClaims));
}

request.pause();
Future<ExtractedClaims> extractedClaims = tokenValidator.extractClaims(authorization);

extractedClaims.onFailure(error -> onExtractClaimsFailure(error, request))
.compose(claims -> onExtractClaimsSuccess(claims, config, request, apiKeyData, traceId, spanId))
authorizationResultFuture.compose(result -> processAuthorizationResult(result.extractedClaims, config, request, result.apiKeyData, traceId, spanId))
.onComplete(ignore -> request.resume());
}

private record AuthorizationResult(ApiKeyData apiKeyData, ExtractedClaims extractedClaims) {

}

private void onExtractClaimsFailure(Throwable error, HttpServerRequest request) {
log.error("Can't extract claims from authorization header", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

private void onGettingApiKeyDataFailure(Throwable error, HttpServerRequest request) {
log.error("Can't find data associated with API key", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

@SneakyThrows
private Future<?> onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
private Future<?> processAuthorizationResult(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
Future<?> future;
try {
ProxyContext context = new ProxyContext(config, request, apiKeyData, extractedClaims, traceId, spanId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,22 @@ private void handleRateLimitSuccess(String deploymentId) {
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignApiKey(proxyApiKeyData);

proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
handleRequestBody(body);
return null;
}))
.onFailure(this::handleRequestBodyError);
proxy.getVertx().executeBlocking(() -> {
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignApiKey(proxyApiKeyData);

proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
handleRequestBody(body);
return null;
}))
.onFailure(this::handleRequestBodyError);
return null;
}).onFailure(this::handleError);
}

private void handleRateLimitHit(RateLimitResult result) {
Expand Down Expand Up @@ -203,6 +206,9 @@ void handleRequestBody(Buffer requestBody) {

try {
ProxyUtil.collectAttachedFiles(tree, this::processAttachedFile);
// update api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().updateApiKeyData(destApiKeyData);
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
Expand Down Expand Up @@ -619,8 +625,14 @@ private Future<Void> respond(HttpStatus status, Object result) {

private void finalizeRequest() {
proxy.getTokenStatsTracker().endSpan(context);
if (context.getProxyApiKeyData() != null) {
proxy.getApiKeyStore().invalidateApiKey(context.getProxyApiKeyData());
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidateApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key: {}", error));
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
public enum ResourceType {
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts"), LIMIT("limits"),
SHARED_WITH_ME("shared_with_me"), SHARED_BY_ME("shared_by_me"), INVITATION("invitations"),
PUBLICATION("publications"), RULES("rules");
PUBLICATION("publications"), RULES("rules"), API_KEY_DATA("api_key_data");

private final String group;

Expand Down
104 changes: 82 additions & 22 deletions src/main/java/com/epam/aidial/core/security/ApiKeyStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,117 @@

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.annotation.concurrent.GuardedBy;

import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;
import static com.epam.aidial.core.storage.BlobStorageUtil.PATH_SEPARATOR;

@Slf4j
@AllArgsConstructor
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
public class ApiKeyStore {

public static final String API_KEY_DATA_BUCKET = "api_key_data";
public static final String API_KEY_DATA_LOCATION = API_KEY_DATA_BUCKET + PATH_SEPARATOR;

private final ResourceService resourceService;

private final LockService lockService;

private final Vertx vertx;

/**
* API keys are captured from secure storage.
*/
@GuardedBy("this")
private final Map<String, ApiKeyData> keys = new HashMap<>();

public synchronized void assignApiKey(ApiKeyData data) {
String apiKey = generateApiKey();
keys.put(apiKey, data);
data.setPerRequestKey(apiKey);
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
ResourceDescription resource = generateApiKey();
String apiKey = resource.getName();
data.setPerRequestKey(apiKey);
String json = ProxyUtil.convertToString(data);
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", apiKey));
}
return apiKey;
});
}

public synchronized ApiKeyData getApiKeyData(String key) {
return keys.get(key);
public synchronized Future<ApiKeyData> getApiKeyData(String key) {
ApiKeyData apiKeyData = keys.get(key);
if (apiKeyData != null) {
return Future.succeededFuture(apiKeyData);
}
ResourceDescription resource = toResource(key);
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class));
}

public synchronized void invalidateApiKey(ApiKeyData apiKeyData) {
if (apiKeyData.getPerRequestKey() != null) {
keys.remove(apiKeyData.getPerRequestKey());
public Future<Boolean> invalidateApiKey(ApiKeyData apiKeyData) {
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey != null) {
ResourceDescription resource = toResource(apiKey);
return vertx.executeBlocking(() -> resourceService.deleteResource(resource));
}
return Future.succeededFuture(true);
}

public synchronized void addProjectKeys(Map<String, Key> projectKeys) {
keys.values().removeIf(apiKeyData -> apiKeyData.getPerRequestKey() == null);
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String key = entry.getKey();
Key value = entry.getValue();
if (keys.containsKey(key)) {
key = generateApiKey();
keys.clear();
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
ResourceDescription resource = toResource(apiKey);
if (resourceService.hasResource(resource)) {
resource = generateApiKey();
apiKey = resource.getName();
}
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(apiKey, apiKeyData);
}
value.setKey(key);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(key, apiKeyData);
return null;
});
}

public void updateApiKeyData(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey == null) {
return;
}
String json = ProxyUtil.convertToString(apiKeyData);
ResourceDescription resource = toResource(apiKey);
resourceService.putResource(resource, json, true, false);
}

private String generateApiKey() {
private ResourceDescription generateApiKey() {
String apiKey = generateKey();
while (keys.containsKey(apiKey)) {
ResourceDescription resource = toResource(apiKey);
while (resourceService.hasResource(resource) || keys.containsKey(apiKey)) {
log.warn("duplicate API key is found. Trying to generate a new one");
apiKey = generateKey();
resource = toResource(apiKey);
}
return apiKey;
return resource;
}

private static ResourceDescription toResource(String apiKey) {
return ResourceDescription.fromDecoded(
ResourceType.API_KEY_DATA, API_KEY_DATA_BUCKET, API_KEY_DATA_LOCATION, apiKey);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private ResourceFolderMetadata getFolderMetadata(ResourceDescription descriptor,
return new ResourceFolderMetadata(descriptor, resources).setNextToken(set.getNextMarker());
}

@Nullable
public ResourceItemMetadata getResourceMetadata(ResourceDescription descriptor) {
if (descriptor.isFolder()) {
throw new IllegalArgumentException("Resource folder: " + descriptor.getUrl());
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/com/epam/aidial/core/FileApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void testBucket(Vertx vertx, VertxTestContext context) {
public void testPerRequestBucket(Vertx vertx, VertxTestContext context) {
// creating per-request API key with proxyKey1 as originator
// and proxyKey2 caller
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey1");
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey1").result();
ApiKeyData apiKeyData2 = new ApiKeyData();
apiKeyData2.setOriginalKey(projectApiKeyData.getOriginalKey());

Expand Down Expand Up @@ -344,7 +344,7 @@ public void testFileUpload(Vertx vertx, VertxTestContext context) {
public void testFileUploadToAppdata(Vertx vertx, VertxTestContext context) {
// creating per-request API key with proxyKey1 as originator
// and proxyKey2 caller
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey1");
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey1").result();
ApiKeyData apiKeyData2 = new ApiKeyData();
apiKeyData2.setOriginalKey(projectApiKeyData.getOriginalKey());
// set deployment ID for proxyKey2
Expand Down Expand Up @@ -436,7 +436,7 @@ public void testDownloadSharedFile(Vertx vertx, VertxTestContext context) {

// creating per-request API key with proxyKey2 as originator
// and proxyKey1 caller
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey2");
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey2").result();
ApiKeyData apiKeyData1 = new ApiKeyData();
apiKeyData1.setOriginalKey(projectApiKeyData.getOriginalKey());
// set deployment ID for proxyKey1
Expand Down Expand Up @@ -505,7 +505,7 @@ public void testDownloadFileWithinSharedFolder(Vertx vertx, VertxTestContext con

// creating per-request API key with proxyKey2 as originator
// and proxyKey1 caller
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey2");
ApiKeyData projectApiKeyData = apiKeyStore.getApiKeyData("proxyKey2").result();
ApiKeyData apiKeyData1 = new ApiKeyData();
apiKeyData1.setOriginalKey(projectApiKeyData.getOriginalKey());
// set deployment ID for proxyKey1
Expand Down
Loading
Loading