Skip to content

Commit

Permalink
feat: Support distributed tracing #79 (#99)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Jan 4, 2024
1 parent 2e06dbb commit f6a9403
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 31 deletions.
18 changes: 18 additions & 0 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
import com.epam.deltix.gflog.core.LogConfigurator;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.registry.otlp.OtlpMeterRegistry;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator;
import io.opentelemetry.context.propagation.ContextPropagators;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import io.opentelemetry.sdk.trace.SdkTracerProvider;
import io.vertx.config.spi.utils.JsonObjectHelper;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
Expand All @@ -26,6 +31,7 @@
import io.vertx.core.json.JsonObject;
import io.vertx.core.metrics.MetricsOptions;
import io.vertx.micrometer.MicrometerMetricsOptions;
import io.vertx.tracing.opentelemetry.OpenTelemetryOptions;
import lombok.extern.slf4j.Slf4j;

import java.io.Closeable;
Expand All @@ -51,11 +57,13 @@ public class AiDial {

@VisibleForTesting
void start() throws Exception {
System.setProperty("io.opentelemetry.context.contextStorageProvider", "io.vertx.tracing.opentelemetry.VertxContextStorageProvider");
try {
settings = settings();

VertxOptions vertxOptions = new VertxOptions(settings("vertx"));
setupMetrics(vertxOptions);
setupTracing(vertxOptions);

vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));
Expand Down Expand Up @@ -211,4 +219,14 @@ private static void setupMetrics(VertxOptions options) {

options.setMetricsOptions(micrometer);
}

/**
 * Enables OpenTelemetry-based tracing on the given Vert.x options.
 *
 * <p>The SDK is built with a default tracer provider and the W3C trace-context
 * propagator; no exporter is configured here.
 *
 * @param vertxOptions Vert.x options that receive the tracing configuration
 */
private static void setupTracing(VertxOptions vertxOptions) {
    ContextPropagators propagators = ContextPropagators.create(W3CTraceContextPropagator.getInstance());
    OpenTelemetry telemetry = OpenTelemetrySdk.builder()
            .setTracerProvider(SdkTracerProvider.builder().build())
            .setPropagators(propagators)
            .build();
    OpenTelemetryOptions tracingOptions = new OpenTelemetryOptions(telemetry);
    vertxOptions.setTracingOptions(tracingOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class DeploymentPostController {
public Future<?> handle(String deploymentId, String deploymentApi) {
String contentType = context.getRequest().getHeader(HttpHeaders.CONTENT_TYPE);
if (!StringUtils.containsIgnoreCase(contentType, Proxy.HEADER_CONTENT_TYPE_APPLICATION_JSON)) {
return context.respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "Only application/json is supported");
return respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "Only application/json is supported");
}

Deployment deployment = context.getConfig().selectDeployment(deploymentId);
Expand All @@ -79,6 +79,11 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
return context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
}

if (!proxy.getRateLimiter().register(context)) {
log.warn("Trace is not found by id for request: method={}, uri={}", context.getRequest().method(), context.getRequest().uri());
return respond(HttpStatus.BAD_REQUEST, "Trace is not found");
}

context.setDeployment(deployment);

RateLimitResult rateLimitResult;
Expand All @@ -88,7 +93,7 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
rateLimitError.getError().setCode(String.valueOf(rateLimitResult.status().getCode()));
rateLimitError.getError().setMessage(rateLimitResult.errorMessage());
log.error("Rate limit error {}. Key: {}. User sub: {}", rateLimitResult.errorMessage(), context.getProject(), context.getUserSub());
return context.respond(rateLimitResult.status(), rateLimitError);
return respond(rateLimitResult.status(), rateLimitError);
}

log.info("Received request from client. Key: {}. Deployment: {}. Headers: {}", context.getProject(),
Expand All @@ -100,7 +105,7 @@ public Future<?> handle(String deploymentId, String deploymentApi) {

if (!endpointRoute.hasNext()) {
log.error("No route. Key: {}. Deployment: {}. User sub: {}", context.getProject(), deploymentId, context.getUserSub());
return context.respond(HttpStatus.BAD_GATEWAY, "No route");
return respond(HttpStatus.BAD_GATEWAY, "No route");
}

return context.getRequest().body()
Expand All @@ -115,7 +120,7 @@ private Future<?> sendRequest() {

if (!route.hasNext()) {
log.error("No route. Key: {}. Deployment: {}. User sub: {}", context.getProject(), context.getDeployment().getName(), context.getUserSub());
return context.respond(HttpStatus.BAD_GATEWAY, "No route");
return respond(HttpStatus.BAD_GATEWAY, "No route");
}

Upstream upstream = route.next();
Expand Down Expand Up @@ -146,11 +151,11 @@ void handleRequestBody(Buffer requestBody) {
context.setRequestBody(enhancedRequest.getKey());
context.setRequestHeaders(enhancedRequest.getValue());
} catch (HttpException e) {
context.respond(e.getStatus(), e.getMessage());
respond(e.getStatus(), e.getMessage());
log.warn("Can't enhance assistant request: {}", e.getMessage());
return;
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
respond(HttpStatus.BAD_REQUEST);
log.warn("Can't enhance assistant request: {}", e.getMessage());
return;
}
Expand Down Expand Up @@ -241,7 +246,8 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
.endOnFailure(false)
.to(response)
.onSuccess(ignored -> handleResponse())
.onFailure(this::handleResponseError);
.onFailure(this::handleResponseError)
.onComplete(ignore -> proxy.getRateLimiter().unregister());
}

/**
Expand Down Expand Up @@ -288,7 +294,7 @@ private void handleResponse() {
*/
private void handleRequestBodyError(Throwable error) {
    log.warn("Failed to receive client body: {}", error.getMessage());
    // Respond exactly once, through the helper that also unregisters the trace from the rate limiter.
    respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
}

/**
Expand All @@ -300,7 +306,7 @@ private void handleProxyConnectionError(Throwable error) {
String uri = buildUri(context);
log.warn("Can't connect to origin. Key: {}. Deployment: {}. Address: {}: {}", projectName,
deploymentName, uri, error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "Failed to connect to origin");
respond(HttpStatus.BAD_GATEWAY, "Failed to connect to origin");
}

/**
Expand All @@ -312,7 +318,7 @@ private void handleProxyResponseError(Throwable error) {
SocketAddress proxyAddress = context.getProxyRequest().connection().remoteAddress();
log.warn("Proxy received response error from origin. Key: {}. Deployment: {}. Address: {}: {}", projectName,
deploymentName, proxyAddress, error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "Received error response from origin");
respond(HttpStatus.BAD_GATEWAY, "Received error response from origin");
}

/**
Expand Down Expand Up @@ -431,8 +437,7 @@ private static Buffer enhanceModelRequest(ProxyContext context) throws Exception
tree.remove("model");
tree.put("model", overrideName);

Buffer updatedBody = Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
return updatedBody;
return Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
}
}

Expand All @@ -456,4 +461,19 @@ private static void deletePrompt(ArrayNode messages) {
private static boolean isBaseAssistant(Deployment deployment) {
return deployment.getName().equals(Config.ASSISTANT);
}

/**
 * Asks the rate limiter to unregister the current trace, then answers the
 * client with the given status and error message.
 *
 * @param status HTTP status to send
 * @param errorMessage error text to send to the client
 * @return the future returned by the proxy context for the response
 */
private Future<Void> respond(HttpStatus status, String errorMessage) {
    proxy.getRateLimiter().unregister();
    Future<Void> response = context.respond(status, errorMessage);
    return response;
}

/**
 * Asks the rate limiter to unregister the current trace, then answers the
 * client with the given status and no body.
 *
 * @param status HTTP status to send
 * @return the future returned by the proxy context for the response
 */
private Future<Void> respond(HttpStatus status) {
    proxy.getRateLimiter().unregister();
    Future<Void> response = context.respond(status);
    return response;
}

/**
 * Asks the rate limiter to unregister the current trace, then answers the
 * client with the given status and result object.
 *
 * @param status HTTP status to send
 * @param result object passed to the proxy context as the response payload
 * @return the future returned by the proxy context for the response
 */
private Future<Void> respond(HttpStatus status, Object result) {
    proxy.getRateLimiter().unregister();
    Future<Void> response = context.respond(status, result);
    return response;
}
}
92 changes: 73 additions & 19 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import lombok.RequiredArgsConstructor;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.sdk.trace.ReadableSpan;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
@RequiredArgsConstructor
public class RateLimiter {

private final ConcurrentHashMap<String, Entity> traceIdToEntity = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Id, RateLimit> rates = new ConcurrentHashMap<>();

public void increase(ProxyContext context) {
Key key = context.getKey();
if (key == null || key.getRole() == null) {
Entity entity = getEntityFromTracingContext();
if (entity == null || entity.user()) {
return;
}
Deployment deployment = context.getDeployment();
Expand All @@ -30,27 +31,29 @@ public void increase(ProxyContext context) {
return;
}

Id id = new Id(key.getKey(), deployment.getName());
Id id = new Id(entity.id(), deployment.getName(), entity.user());
RateLimit rate = rates.computeIfAbsent(id, k -> new RateLimit());

long timestamp = System.currentTimeMillis();
rate.add(timestamp, usage.getTotalTokens());
}

public RateLimitResult limit(ProxyContext context) {
Key key = context.getKey();
if (key == null || key.getRole() == null) {
return RateLimitResult.SUCCESS;
}
Role role = context.getConfig().getRoles().get(key.getRole());

if (role == null) {
log.warn("Role is not found for key: {}", context.getKey().getKey());
Entity entity = getEntityFromTracingContext();
if (entity == null) {
Span span = Span.current();
log.warn("Entity is not found by traceId={}", span.getSpanContext().getTraceId());
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied");
}
Limit limit;
if (entity.user()) {
// don't support user limits yet
return RateLimitResult.SUCCESS;
} else {
limit = getLimitByApiKey(context, entity);
}

Deployment deployment = context.getDeployment();
Limit limit = role.getLimits().get(deployment.getName());

if (limit == null || !limit.isPositive()) {
if (limit == null) {
Expand All @@ -61,7 +64,7 @@ public RateLimitResult limit(ProxyContext context) {
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied");
}

Id id = new Id(key.getKey(), deployment.getName());
Id id = new Id(entity.id(), deployment.getName(), entity.user());
RateLimit rate = rates.get(id);

if (rate == null) {
Expand All @@ -72,10 +75,61 @@ public RateLimitResult limit(ProxyContext context) {
return rate.update(timestamp, limit);
}

private record Id(String key, String resource) {
/**
 * Binds the trace of the current span to the caller of the given request.
 *
 * <p>When the current span has a remote parent, nothing is stored: the call only
 * reports whether the trace id is already known. Otherwise the caller (API key or
 * user) is stored under the trace id.
 *
 * @param context request context supplying the API key or the user identity
 * @return {@code true} if the trace is known or has just been registered,
 *         {@code false} if the trace id is not known
 */
public boolean register(ProxyContext context) {
    Span current = Span.current();
    String traceId = current.getSpanContext().getTraceId();
    if (!(current instanceof ReadableSpan span)) {
        // The current span is not guaranteed to be a ReadableSpan (e.g. no active span, or a
        // non-recording span produced when the sampler drops it). Such a span does not expose its
        // parent, so it cannot start a registration; accept it only if the trace is already known.
        return traceIdToEntity.containsKey(traceId);
    }
    if (span.getParentSpanContext().isRemote()) {
        return traceIdToEntity.containsKey(traceId);
    }

    Key key = context.getKey();
    // NOTE(review): List.of rejects null, so an API key without a role fails here - confirm the role is mandatory.
    Entity entity = (key != null)
            ? new Entity(key.getKey(), List.of(key.getRole()), false)
            : new Entity(context.getUserSub(), context.getUserRoles(), true);
    traceIdToEntity.put(traceId, entity);
    return true;
}

/**
 * Removes the entry stored for the current trace, unless the current span has a
 * remote parent (in which case this instance did not create the entry).
 *
 * <p>Does nothing when the current span is not a {@link ReadableSpan}, because
 * such a span can never have created an entry in {@code register}.
 */
public void unregister() {
    Span current = Span.current();
    // Checked instead of cast blindly: this method runs on every error response path.
    if (current instanceof ReadableSpan span && !span.getParentSpanContext().isRemote()) {
        traceIdToEntity.remove(span.getSpanContext().getTraceId());
    }
}

/**
 * Resolves the limit configured for the entity's role on the deployment of the given request.
 *
 * @param context request context supplying the configuration and the target deployment
 * @param entity API-key entity bound to the current trace
 * @return the configured limit, or {@code null} if the role is unknown or has no limit for the deployment
 */
private Limit getLimitByApiKey(ProxyContext context, Entity entity) {
    // API key has always one role
    Role role = context.getConfig().getRoles().get(entity.roles().get(0));

    if (role == null) {
        // Log the id stored in the entity: the entity may have been registered by an upstream
        // request of the same trace, so the current request is not guaranteed to carry a key.
        log.warn("Role is not found for key: {}", entity.id());
        return null;
    }

    Deployment deployment = context.getDeployment();
    return role.getLimits().get(deployment.getName());
}

/**
 * Looks up the entity registered for the trace of the current span.
 *
 * @return the registered entity, or {@code null} if the trace id is not known
 */
protected Entity getEntityFromTracingContext() {
    // The span context is available on every Span, so no cast to ReadableSpan is needed
    // (the cast would throw for a span that is not a ReadableSpan).
    String traceId = Span.current().getSpanContext().getTraceId();
    return traceIdToEntity.get(traceId);
}

/**
 * Key of a rate-limit bucket: the caller's id, the resource (deployment) name
 * and a flag telling whether the caller is a user rather than an API key.
 */
private record Id(String key, String resource, boolean user) {
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("key: ").append(key);
        text.append(", resource: ").append(resource);
        text.append(", user: ").append(user);
        return text.toString();
    }
}

private record Entity(String id, List<String> roles, boolean user) {
@Override
public String toString() {
return String.format("key: %s, resource: %s", key, resource);
return String.format("Entity: %s, resource: %s, user: %b", id, roles, user);
}
}
}

}
67 changes: 67 additions & 0 deletions src/test/java/com/epam/aidial/core/TracerApiTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.epam.aidial.core;

import io.vertx.core.Vertx;
import io.vertx.ext.web.client.WebClient;
import io.vertx.junit5.VertxExtension;
import io.vertx.junit5.VertxTestContext;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.nio.file.Path;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Integration test for trace handling: starts a real {@link AiDial} server and
 * sends HTTP requests to it.
 */
@ExtendWith(VertxExtension.class)
public class TracerApiTest {
    private static AiDial dial;
    private static int serverPort;
    private static Path testDir;

    @BeforeAll
    public static void init() throws Exception {
        // initialize server
        dial = new AiDial();
        // Use this class (not FileApiTest) so the directory created and deleted here
        // is not shared with another test class.
        testDir = FileUtil.baseTestPath(TracerApiTest.class);
        dial.setStorage(FileUtil.buildFsBlobStorage(testDir));
        dial.start();
        serverPort = dial.getServer().actualPort();
    }

    @BeforeEach
    public void setUp() {
        // prepare test directory
        FileUtil.createDir(testDir.resolve("test"));
    }

    @AfterEach
    public void clean() {
        // clean test directory
        FileUtil.deleteDir(testDir);
    }

    @AfterAll
    public static void destroy() {
        // stop server
        dial.stop();
    }

    /**
     * A request that carries a {@code traceparent} header for a trace the server has
     * never registered must be rejected with 400.
     */
    @Test
    public void testTraceNotFound(Vertx vertx, VertxTestContext context) {
        WebClient client = WebClient.create(vertx);
        client.post(serverPort, "localhost", "/openai/deployments/app/chat/completions")
                .putHeader("Api-key", "proxyKey2")
                .putHeader("content-type", "application/json")
                .putHeader("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
                // Release the client once the request has finished, whatever the outcome.
                .send(result -> {
                    client.close();
                    if (result.failed()) {
                        context.failNow(result.cause());
                        return;
                    }
                    context.verify(() -> {
                        assertEquals(400, result.result().statusCode());
                        context.completeNow();
                    });
                });
    }

}
Loading

0 comments on commit f6a9403

Please sign in to comment.