From 202817829bf68d8e7fde2f3131f555cd5aca7235 Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Mon, 8 Jan 2024 13:27:21 +0300 Subject: [PATCH] feat: Calculate token usage statistics using call stack #117 --- src/main/java/com/epam/aidial/core/Proxy.java | 8 +- .../com/epam/aidial/core/ProxyContext.java | 20 +++- .../controller/DeploymentPostController.java | 4 +- .../epam/aidial/core/limiter/RateLimiter.java | 109 +++++++++++++++--- .../com/epam/aidial/core/log/GfLogStore.java | 16 ++- .../epam/aidial/core/token/TokenUsage.java | 9 ++ .../aidial/core/limiter/RateLimiterTest.java | 47 ++++++-- 7 files changed, 179 insertions(+), 34 deletions(-) diff --git a/src/main/java/com/epam/aidial/core/Proxy.java b/src/main/java/com/epam/aidial/core/Proxy.java index 0f7582a2d..33cab2c5e 100644 --- a/src/main/java/com/epam/aidial/core/Proxy.java +++ b/src/main/java/com/epam/aidial/core/Proxy.java @@ -116,7 +116,7 @@ private void handleRequest(HttpServerRequest request) { Config config = configStore.load(); String apiKey = request.headers().get(HEADER_API_KEY); String authorization = request.getHeader(HttpHeaders.AUTHORIZATION); - String traceId = Span.current().getSpanContext().getTraceId(); + Span currentSpan = Span.current(); log.debug("Authorization header: {}", authorization); Key key; if (apiKey == null && authorization == null) { @@ -145,7 +145,7 @@ private void handleRequest(HttpServerRequest request) { extractedClaims.onComplete(result -> { try { if (result.succeeded()) { - onExtractClaimsSuccess(result.result(), config, request, key, traceId); + onExtractClaimsSuccess(result.result(), config, request, key, currentSpan); } else { onExtractClaimsFailure(result.cause(), request); } @@ -163,8 +163,8 @@ private void onExtractClaimsFailure(Throwable error, HttpServerRequest request) } private void onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config, - HttpServerRequest request, Key key, String traceId) throws Exception { - ProxyContext context = new ProxyContext(config, request, key, extractedClaims, traceId); + HttpServerRequest request, Key key, Span span) throws Exception { + ProxyContext context = new ProxyContext(config, request, key, extractedClaims, span); Controller controller = ControllerSelector.select(this, context); controller.handle(); } diff --git a/src/main/java/com/epam/aidial/core/ProxyContext.java b/src/main/java/com/epam/aidial/core/ProxyContext.java index eed75d1c0..6a6b1cbbe 100644 --- a/src/main/java/com/epam/aidial/core/ProxyContext.java +++ b/src/main/java/com/epam/aidial/core/ProxyContext.java @@ -9,6 +9,8 @@ import com.epam.aidial.core.util.BufferingReadStream; import com.epam.aidial.core.util.HttpStatus; import com.epam.aidial.core.util.ProxyUtil; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.sdk.trace.ReadableSpan; import io.vertx.core.Future; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpClientRequest; @@ -31,7 +33,7 @@ public class ProxyContext { private final Key key; private final HttpServerRequest request; private final HttpServerResponse response; - private final String traceId; + private final Span span; private Deployment deployment; private String userSub; @@ -53,7 +55,7 @@ public class ProxyContext { // the project belongs to API key which initiated request private String originalProject; - public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, String traceId) { + public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, Span span) { this.config = config; this.key = key; if (key != null) { @@ -68,7 +70,7 @@ public ProxyContext(Config config, HttpServerRequest request, Key key, Extracted this.userHash = extractedClaims.userHash(); this.userSub = extractedClaims.sub(); } - this.traceId = traceId; + this.span = span; } public Future respond(HttpStatus status) { @@ -90,4 +92,16 @@ public Future respond(HttpStatus status, String body) { public String getProject() { return key == null ? null : key.getProject(); } + + public String getTraceId() { + return span.getSpanContext().getTraceId(); + } + + public String getCurrentSpanId() { + return span.getSpanContext().getSpanId(); + } + + public String getParentSpanId() { + return ((ReadableSpan) span).getParentSpanContext().getSpanId(); + } } \ No newline at end of file diff --git a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java index 22b7e1dce..bbcc51397 100644 --- a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java +++ b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java @@ -255,7 +255,6 @@ private void handleResponse() { Buffer responseBody = context.getResponseStream().getContent(); context.setResponseBody(responseBody); context.setResponseBodyTimestamp(System.currentTimeMillis()); - proxy.getLogStore().save(context); if (context.getDeployment() instanceof Model && context.getResponse().getStatusCode() == HttpStatus.OK.getCode()) { TokenUsage tokenUsage = TokenUsageParser.parse(responseBody); @@ -272,6 +271,9 @@ private void handleResponse() { } } + proxy.getRateLimiter().calculateTokenUsage(context); + proxy.getLogStore().save(context); + log.info("Sent response to client. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}." + " Timing: {} (body={}, connect={}, header={}, body={}). Tokens: {}", context.getProject(), context.getDeployment().getName(), diff --git a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java index 37ed4bf95..276fcb0ea 100644 --- a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java +++ b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java @@ -7,11 +7,14 @@ import com.epam.aidial.core.config.Role; import com.epam.aidial.core.token.TokenUsage; import com.epam.aidial.core.util.HttpStatus; -import io.opentelemetry.api.trace.Span; +import lombok.Data; import lombok.extern.slf4j.Slf4j; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @Slf4j @@ -21,7 +24,7 @@ public class RateLimiter { public void increase(ProxyContext context) { Entity entity = getEntityFromTracingContext(context); - if (entity == null || entity.user()) { + if (entity == null || entity.isUser()) { return; } Deployment deployment = context.getDeployment(); @@ -31,7 +34,9 @@ public void increase(ProxyContext context) { return; } - Id id = new Id(entity.id(), deployment.getName(), entity.user()); + entity.setTokeUsage(context.getCurrentSpanId(), usage); + + Id id = new Id(entity.getId(), deployment.getName(), entity.isUser()); RateLimit rate = rates.computeIfAbsent(id, k -> new RateLimit()); long timestamp = System.currentTimeMillis(); @@ -41,12 +46,11 @@ public void increase(ProxyContext context) { public RateLimitResult limit(ProxyContext context) { Entity entity = getEntityFromTracingContext(context); if (entity == null) { - Span span = Span.current(); - log.warn("Entity is not found by traceId={}", span.getSpanContext().getTraceId()); + log.warn("Entity is not found by traceId={}", context.getTraceId()); return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"); } Limit limit; - if (entity.user()) { + if (entity.isUser()) { // don't support user limits yet return RateLimitResult.SUCCESS; } else { @@ -64,7 +68,7 @@ public RateLimitResult limit(ProxyContext context) { return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"); } - Id id = new Id(entity.id(), deployment.getName(), entity.user()); + Id id = new Id(entity.getId(), deployment.getName(), entity.isUser()); RateLimit rate = rates.get(id); if (rate == null) { @@ -81,22 +85,36 @@ public RateLimitResult limit(ProxyContext context) { public boolean register(ProxyContext context) { String traceId = context.getTraceId(); Entity entity = traceIdToEntity.get(traceId); - if (entity != null) { + boolean result = entity != null; + if (result) { + entity.link(context.getParentSpanId(), context.getCurrentSpanId()); // update context with the original requester - if (entity.user()) { - context.setUserHash(entity.name()); + if (entity.isUser()) { + context.setUserHash(entity.getName()); } else { - context.setOriginalProject(entity.name()); + context.setOriginalProject(entity.getName()); } } else { if (context.getKey() != null) { Key key = context.getKey(); - traceIdToEntity.put(traceId, new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false)); + entity = new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false); + traceIdToEntity.put(traceId, entity); } else { - traceIdToEntity.put(traceId, new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true)); + entity = new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true); + traceIdToEntity.put(traceId, entity); } } - return entity != null; + entity.addCall(context.getCurrentSpanId()); + return result; + } + + public void calculateTokenUsage(ProxyContext context) { + Entity entity = getEntityFromTracingContext(context); + if (entity == null) { + log.warn("Entity is not found by traceId={}", context.getTraceId()); + return; + } + context.setTokenUsage(entity.calculate(context.getCurrentSpanId())); } public void unregister(ProxyContext context) { @@ -129,11 +147,72 @@ public String toString() { } } - private record Entity(String id, List roles, String name, boolean user) { + @Data + private static class Entity { + + private final String id; + private final List roles; + private final String name; + private final boolean user; + private final Map spanIdToCallInfo = new HashMap<>(); + + public Entity(String id, List roles, String name, boolean user) { + this.id = id; + this.roles = roles; + this.name = name; + this.user = user; + } + + public synchronized void addCall(String spanId) { + spanIdToCallInfo.put(spanId, new CallInfo()); + } + + public synchronized void link(String parentSpanId, String childSpanId) { + CallInfo callInfo = spanIdToCallInfo.get(parentSpanId); + if (callInfo == null) { + log.warn("Parent span is not found by id: {}", parentSpanId); + return; + } + callInfo.childSpanIds.add(childSpanId); + } + + public synchronized void setTokeUsage(String spanId, TokenUsage tokenUsage) { + CallInfo callInfo = spanIdToCallInfo.get(spanId); + if (callInfo == null) { + log.warn("Span is not found by id: {}", spanId); + return; + } + callInfo.setTokenUsage(tokenUsage); + } + + public synchronized TokenUsage calculate(String spanId) { + CallInfo callInfo = spanIdToCallInfo.get(spanId); + if (callInfo == null) { + log.warn("Span is not found by id: {}", spanId); + return null; + } + TokenUsage tokenUsage = callInfo.tokenUsage; + for (String childSpanId : callInfo.childSpanIds) { + CallInfo childCall = spanIdToCallInfo.get(childSpanId); + if (childCall == null) { + log.warn("Child span is not found by id: {}", childSpanId); + continue; + } + tokenUsage.increase(childCall.tokenUsage); + } + return tokenUsage; + } + @Override public String toString() { return String.format("Entity: %s, resource: %s, user: %b", id, roles, user); } } + @Data + private static class CallInfo { + TokenUsage tokenUsage = new TokenUsage(); + List childSpanIds = new ArrayList<>(); + } + } diff --git a/src/main/java/com/epam/aidial/core/log/GfLogStore.java b/src/main/java/com/epam/aidial/core/log/GfLogStore.java index ef4f8aade..9b7a79985 100644 --- a/src/main/java/com/epam/aidial/core/log/GfLogStore.java +++ b/src/main/java/com/epam/aidial/core/log/GfLogStore.java @@ -2,6 +2,7 @@ import com.epam.aidial.core.Proxy; import com.epam.aidial.core.ProxyContext; +import com.epam.aidial.core.token.TokenUsage; import com.epam.aidial.core.util.HttpStatus; import com.epam.deltix.gflog.api.Log; import com.epam.deltix.gflog.api.LogEntry; @@ -94,8 +95,19 @@ private void append(ProxyContext context, LogEntry entry) { append(entry, "\",\"body\":\"", false); append(entry, context.getResponseBody()); - - append(entry, "\"}}", false); + TokenUsage tokenUsage = context.getTokenUsage(); + if (tokenUsage != null) { + append(entry, "\"},\"tokenUsage\":{", false); + append(entry, "\"completion_tokens\":", false); + append(entry, Long.toString(tokenUsage.getCompletionTokens()), true); + append(entry, ",\"prompt_tokens\":", false); + append(entry, Long.toString(tokenUsage.getPromptTokens()), true); + append(entry, ",\"total_tokens\":", false); + append(entry, Long.toString(tokenUsage.getTotalTokens()), true); + append(entry, "}}", false); + } else { + append(entry, "\"}}", false); + } } diff --git a/src/main/java/com/epam/aidial/core/token/TokenUsage.java b/src/main/java/com/epam/aidial/core/token/TokenUsage.java index f9e3c341e..26e7f72a2 100644 --- a/src/main/java/com/epam/aidial/core/token/TokenUsage.java +++ b/src/main/java/com/epam/aidial/core/token/TokenUsage.java @@ -8,6 +8,15 @@ public class TokenUsage { private long promptTokens; private long totalTokens; + public void increase(TokenUsage usage) { + if (usage == null) { + return; + } + completionTokens += usage.completionTokens; + promptTokens += usage.promptTokens; + totalTokens += usage.totalTokens; + } + @Override public String toString() { return "completion=" + completionTokens diff --git a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java index 89c64a520..c7e819f29 100644 --- a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java +++ b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java @@ -9,6 +9,9 @@ import com.epam.aidial.core.security.ExtractedClaims; import com.epam.aidial.core.security.IdentityProvider; import com.epam.aidial.core.util.HttpStatus; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.sdk.trace.ReadableSpan; import io.vertx.core.http.HttpServerRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -23,6 +26,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class RateLimiterTest { @@ -30,20 +35,28 @@ public class RateLimiterTest { @Mock private HttpServerRequest request; + @Mock + private TestSpan span; + + @Mock + private SpanContext spanContext; + private RateLimiter rateLimiter; @BeforeEach public void beforeEach() { rateLimiter = new RateLimiter(); + when(span.getSpanContext()).thenReturn(spanContext); } @Test public void testRegister_SuccessNoParentSpan() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setRole("role"); key.setKey("key"); key.setProject("project"); - ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); assertFalse(rateLimiter.register(proxyContext)); assertEquals("project", proxyContext.getOriginalProject()); @@ -57,10 +70,11 @@ public void testRegister_SuccessNoParentSpan() { @Test public void testRegister_SuccessWithNullRole() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setKey("key"); key.setProject("project"); - ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); assertFalse(rateLimiter.register(proxyContext)); assertEquals("project", proxyContext.getOriginalProject()); @@ -74,22 +88,28 @@ public void testRegister_SuccessWithNullRole() { @Test public void testRegister_SuccessParentSpanExists() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setRole("role"); key.setKey("key"); key.setProject("project"); - ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); assertFalse(rateLimiter.register(proxyContext)); assertEquals("project", proxyContext.getOriginalProject()); + SpanContext parentSpanContext = mock(SpanContext.class); + when(parentSpanContext.getSpanId()).thenReturn("parent-id"); + when(span.getParentSpanContext()).thenReturn(parentSpanContext); + assertTrue(rateLimiter.register(proxyContext)); assertEquals("project", proxyContext.getOriginalProject()); } @Test public void testLimit_EntityNotFound() { - ProxyContext proxyContext = new ProxyContext(new Config(), request, new Key(), IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "unknown-trace-id"); + when(spanContext.getTraceId()).thenReturn("unknown-trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, new Key(), IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); RateLimitResult result = rateLimiter.limit(proxyContext); @@ -99,7 +119,8 @@ public void testLimit_EntityNotFound() { @Test public void testLimit_SuccessUser() { - ProxyContext proxyContext = new ProxyContext(new Config(), request, null, new ExtractedClaims("sub", Collections.emptyList(), "hash"), "trace-id"); + when(spanContext.getTraceId()).thenReturn("trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, null, new ExtractedClaims("sub", Collections.emptyList(), "hash"), span); assertFalse(rateLimiter.register(proxyContext)); @@ -111,10 +132,11 @@ public void testLimit_SuccessUser() { @Test public void testLimit_ApiKeyLimitNotFound() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setRole("role"); key.setKey("key"); - ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); proxyContext.setDeployment(new Model()); @@ -130,10 +152,11 @@ public void testLimit_ApiKeyLimitNotFound() { @Test public void testLimit_ApiKeyLimitNotFoundWithNullRole() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setKey("key"); key.setProject("project"); - ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(new Config(), request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); proxyContext.setDeployment(new Model()); @@ -148,6 +171,7 @@ public void testLimit_ApiKeyLimitNotFoundWithNullRole() { @Test public void testLimit_ApiKeyLimitNegative() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setRole("role"); key.setKey("key"); @@ -157,7 +181,7 @@ public void testLimit_ApiKeyLimitNegative() { limit.setDay(-1); role.setLimits(Map.of("model", limit)); config.setRoles(Map.of("role", role)); - ProxyContext proxyContext = new ProxyContext(config, request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(config, request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); Model model = new Model(); model.setName("model"); proxyContext.setDeployment(model); @@ -174,6 +198,7 @@ public void testLimit_ApiKeyLimitNegative() { @Test public void testLimit_ApiKeySuccess() { + when(spanContext.getTraceId()).thenReturn("trace-id"); Key key = new Key(); key.setRole("role"); key.setKey("key"); @@ -182,7 +207,7 @@ public void testLimit_ApiKeySuccess() { Limit limit = new Limit(); role.setLimits(Map.of("model", limit)); config.setRoles(Map.of("role", role)); - ProxyContext proxyContext = new ProxyContext(config, request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, "trace-id"); + ProxyContext proxyContext = new ProxyContext(config, request, key, IdentityProvider.CLAIMS_WITH_EMPTY_ROLES, span); Model model = new Model(); model.setName("model"); proxyContext.setDeployment(model); @@ -195,4 +220,8 @@ public void testLimit_ApiKeySuccess() { assertEquals(HttpStatus.OK, result.status()); } + + private interface TestSpan extends Span, ReadableSpan { + + } }