Skip to content

Commit

Permalink
fix: Core rejects requests if trace id is not found #110 (#111)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Jan 5, 2024
1 parent 6c2cb9b commit e02bab8
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 244 deletions.
8 changes: 5 additions & 3 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.epam.aidial.core.upstream.UpstreamBalancer;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
import io.opentelemetry.api.trace.Span;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
Expand Down Expand Up @@ -115,6 +116,7 @@ private void handleRequest(HttpServerRequest request) {
Config config = configStore.load();
String apiKey = request.headers().get(HEADER_API_KEY);
String authorization = request.getHeader(HttpHeaders.AUTHORIZATION);
String traceId = Span.current().getSpanContext().getTraceId();
log.debug("Authorization header: {}", authorization);
Key key;
if (apiKey == null && authorization == null) {
Expand Down Expand Up @@ -143,7 +145,7 @@ private void handleRequest(HttpServerRequest request) {
extractedClaims.onComplete(result -> {
try {
if (result.succeeded()) {
onExtractClaimsSuccess(result.result(), config, request, key);
onExtractClaimsSuccess(result.result(), config, request, key, traceId);
} else {
onExtractClaimsFailure(result.cause(), request);
}
Expand All @@ -161,8 +163,8 @@ private void onExtractClaimsFailure(Throwable error, HttpServerRequest request)
}

private void onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, Key key) throws Exception {
ProxyContext context = new ProxyContext(config, request, key, extractedClaims);
HttpServerRequest request, Key key, String traceId) throws Exception {
ProxyContext context = new ProxyContext(config, request, key, extractedClaims, traceId);
Controller controller = ControllerSelector.select(this, context);
controller.handle();
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/epam/aidial/core/ProxyContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class ProxyContext {
private final Key key;
private final HttpServerRequest request;
private final HttpServerResponse response;
private final String traceId;

private Deployment deployment;
private String userSub;
Expand All @@ -52,7 +53,7 @@ public class ProxyContext {
// the project belongs to API key which initiated request
private String originalProject;

public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims) {
public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, String traceId) {
this.config = config;
this.key = key;
if (key != null) {
Expand All @@ -67,6 +68,7 @@ public ProxyContext(Config config, HttpServerRequest request, Key key, Extracted
this.userHash = extractedClaims.userHash();
this.userSub = extractedClaims.sub();
}
this.traceId = traceId;
}

public Future<Void> respond(HttpStatus status) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class DeploymentPostController {

private final Proxy proxy;
private final ProxyContext context;
private boolean unregisterTrace;

public Future<?> handle(String deploymentId, String deploymentApi) {
String contentType = context.getRequest().getHeader(HttpHeaders.CONTENT_TYPE);
Expand All @@ -79,10 +80,7 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
return context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
}

if (!proxy.getRateLimiter().register(context)) {
log.warn("Trace is not found by id for request: method={}, uri={}", context.getRequest().method(), context.getRequest().uri());
return respond(HttpStatus.BAD_REQUEST, "Trace is not found");
}
unregisterTrace = !proxy.getRateLimiter().register(context);

context.setDeployment(deployment);

Expand Down Expand Up @@ -247,7 +245,7 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
.to(response)
.onSuccess(ignored -> handleResponse())
.onFailure(this::handleResponseError)
.onComplete(ignore -> proxy.getRateLimiter().unregister());
.onComplete(ignore -> unregister());
}

/**
Expand Down Expand Up @@ -463,17 +461,23 @@ private static boolean isBaseAssistant(Deployment deployment) {
}

private Future<Void> respond(HttpStatus status, String errorMessage) {
proxy.getRateLimiter().unregister();
unregister();
return context.respond(status, errorMessage);
}

private Future<Void> respond(HttpStatus status) {
proxy.getRateLimiter().unregister();
unregister();
return context.respond(status);
}

private Future<Void> respond(HttpStatus status, Object result) {
proxy.getRateLimiter().unregister();
unregister();
return context.respond(status, result);
}

private void unregister() {
if (unregisterTrace) {
proxy.getRateLimiter().unregister(context);
}
}
}
43 changes: 19 additions & 24 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.sdk.trace.ReadableSpan;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
Expand All @@ -20,7 +19,7 @@ public class RateLimiter {
private final ConcurrentHashMap<Id, RateLimit> rates = new ConcurrentHashMap<>();

public void increase(ProxyContext context) {
Entity entity = getEntityFromTracingContext();
Entity entity = getEntityFromTracingContext(context);
if (entity == null || entity.user()) {
return;
}
Expand All @@ -39,7 +38,7 @@ public void increase(ProxyContext context) {
}

public RateLimitResult limit(ProxyContext context) {
Entity entity = getEntityFromTracingContext();
Entity entity = getEntityFromTracingContext(context);
if (entity == null) {
Span span = Span.current();
log.warn("Entity is not found by traceId={}", span.getSpanContext().getTraceId());
Expand Down Expand Up @@ -75,36 +74,33 @@ public RateLimitResult limit(ProxyContext context) {
return rate.update(timestamp, limit);
}

/**
* Returns <code>true</code> if the trace is already registered otherwise <code>false</code>.
*/
public boolean register(ProxyContext context) {
ReadableSpan span = (ReadableSpan) Span.current();
String traceId = span.getSpanContext().getTraceId();
if (span.getParentSpanContext().isRemote()) {
Entity entity = traceIdToEntity.get(traceId);
if (entity != null) {
if (entity.user()) {
context.setUserHash(entity.name());
} else {
context.setOriginalProject(entity.name());
}
String traceId = context.getTraceId();
Entity entity = traceIdToEntity.get(traceId);
if (entity != null) {
// update context with the original requester
if (entity.user()) {
context.setUserHash(entity.name());
} else {
context.setOriginalProject(entity.name());
}
return entity != null;
} else {
if (context.getKey() != null) {
Key key = context.getKey();
traceIdToEntity.put(traceId, new Entity(key.getKey(), List.of(key.getRole()), key.getProject(), false));
} else {
traceIdToEntity.put(traceId, new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true));
}
return true;
}
return entity != null;
}

public void unregister() {
ReadableSpan span = (ReadableSpan) Span.current();
if (!span.getParentSpanContext().isRemote()) {
String traceId = span.getSpanContext().getTraceId();
traceIdToEntity.remove(traceId);
}
public void unregister(ProxyContext context) {
String traceId = context.getTraceId();
traceIdToEntity.remove(traceId);
}

private Limit getLimitByApiKey(ProxyContext context, Entity entity) {
Expand All @@ -120,9 +116,8 @@ private Limit getLimitByApiKey(ProxyContext context, Entity entity) {
return role.getLimits().get(deployment.getName());
}

protected Entity getEntityFromTracingContext() {
ReadableSpan span = (ReadableSpan) Span.current();
String traceId = span.getSpanContext().getTraceId();
private Entity getEntityFromTracingContext(ProxyContext context) {
String traceId = context.getTraceId();
return traceIdToEntity.get(traceId);
}

Expand Down
67 changes: 0 additions & 67 deletions src/test/java/com/epam/aidial/core/TracerApiTest.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ public void beforeEach() {
@Test
public void testUnsupportedContentType() {
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn("unsupported");
when(proxy.getRateLimiter()).thenReturn(rateLimiter);

controller.handle("app1", "api");

Expand Down Expand Up @@ -112,23 +111,6 @@ public void testDeploymentNotFound() {
verify(context).respond(eq(NOT_FOUND), anyString());
}

@Test
public void testTraceNotFound() {
when(proxy.getRateLimiter()).thenReturn(rateLimiter);
when(rateLimiter.register(any(ProxyContext.class))).thenReturn(false);
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(HEADER_CONTENT_TYPE_APPLICATION_JSON);
Config config = new Config();
config.setApplications(new HashMap<>());
Application application = new Application();
application.setName("app1");
config.getApplications().put("app1", application);
when(context.getConfig()).thenReturn(config);

controller.handle("app1", "chat/completions");

verify(context).respond(eq(BAD_REQUEST), anyString());
}

@Test
public void testNoRoute() {
when(proxy.getRateLimiter()).thenReturn(rateLimiter);
Expand Down
Loading

0 comments on commit e02bab8

Please sign in to comment.