Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Calculate token usage statistics using call stack #117 #118

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private void handleRequest(HttpServerRequest request) {
Config config = configStore.load();
String apiKey = request.headers().get(HEADER_API_KEY);
String authorization = request.getHeader(HttpHeaders.AUTHORIZATION);
String traceId = Span.current().getSpanContext().getTraceId();
Span currentSpan = Span.current();
log.debug("Authorization header: {}", authorization);
Key key;
if (apiKey == null && authorization == null) {
Expand Down Expand Up @@ -145,7 +145,7 @@ private void handleRequest(HttpServerRequest request) {
extractedClaims.onComplete(result -> {
try {
if (result.succeeded()) {
onExtractClaimsSuccess(result.result(), config, request, key, traceId);
onExtractClaimsSuccess(result.result(), config, request, key, currentSpan);
} else {
onExtractClaimsFailure(result.cause(), request);
}
Expand All @@ -163,8 +163,8 @@ private void onExtractClaimsFailure(Throwable error, HttpServerRequest request)
}

private void onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, Key key, String traceId) throws Exception {
ProxyContext context = new ProxyContext(config, request, key, extractedClaims, traceId);
HttpServerRequest request, Key key, Span span) throws Exception {
ProxyContext context = new ProxyContext(config, request, key, extractedClaims, span);
Controller controller = ControllerSelector.select(this, context);
controller.handle();
}
Expand Down
20 changes: 17 additions & 3 deletions src/main/java/com/epam/aidial/core/ProxyContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import com.epam.aidial.core.util.BufferingReadStream;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.sdk.trace.ReadableSpan;
import io.vertx.core.Future;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
Expand All @@ -31,7 +33,7 @@ public class ProxyContext {
private final Key key;
private final HttpServerRequest request;
private final HttpServerResponse response;
private final String traceId;
private final Span span;

private Deployment deployment;
private String userSub;
Expand All @@ -53,7 +55,7 @@ public class ProxyContext {
// the project belongs to API key which initiated request
private String originalProject;

public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, String traceId) {
public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, Span span) {
this.config = config;
this.key = key;
if (key != null) {
Expand All @@ -68,7 +70,7 @@ public ProxyContext(Config config, HttpServerRequest request, Key key, Extracted
this.userHash = extractedClaims.userHash();
this.userSub = extractedClaims.sub();
}
this.traceId = traceId;
this.span = span;
}

public Future<Void> respond(HttpStatus status) {
Expand All @@ -90,4 +92,16 @@ public Future<Void> respond(HttpStatus status, String body) {
public String getProject() {
return key == null ? null : key.getProject();
}

/**
 * Returns the trace id read from the span context of the span held by this context.
 */
public String getTraceId() {
    var spanContext = span.getSpanContext();
    return spanContext.getTraceId();
}

/**
 * Returns the span id read from the span context of the span held by this context.
 */
public String getCurrentSpanId() {
    var spanContext = span.getSpanContext();
    return spanContext.getSpanId();
}

/**
 * Returns the span id of the parent of the span held by this context.
 *
 * <p>The parent span context is only reachable through {@link ReadableSpan}. When the held
 * span is of any other type, the id of the invalid span is returned instead of failing with
 * a {@link ClassCastException}.
 *
 * @return the parent span id, or the invalid span id when no parent information is available
 */
public String getParentSpanId() {
    if (span instanceof ReadableSpan readableSpan) {
        return readableSpan.getParentSpanContext().getSpanId();
    }
    // Not a readable SDK span: there is no parent context to inspect.
    return Span.getInvalid().getSpanContext().getSpanId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ private void handleResponse() {
Buffer responseBody = context.getResponseStream().getContent();
context.setResponseBody(responseBody);
context.setResponseBodyTimestamp(System.currentTimeMillis());
proxy.getLogStore().save(context);

if (context.getDeployment() instanceof Model && context.getResponse().getStatusCode() == HttpStatus.OK.getCode()) {
TokenUsage tokenUsage = TokenUsageParser.parse(responseBody);
Expand All @@ -272,6 +271,9 @@ private void handleResponse() {
}
}

proxy.getRateLimiter().calculateTokenUsage(context);
proxy.getLogStore().save(context);

log.info("Sent response to client. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}."
+ " Timing: {} (body={}, connect={}, header={}, body={}). Tokens: {}",
context.getProject(), context.getDeployment().getName(),
Expand Down
109 changes: 94 additions & 15 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import io.opentelemetry.api.trace.Span;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
Expand All @@ -21,7 +24,7 @@ public class RateLimiter {

public void increase(ProxyContext context) {
Entity entity = getEntityFromTracingContext(context);
if (entity == null || entity.user()) {
if (entity == null || entity.isUser()) {
return;
}
Deployment deployment = context.getDeployment();
Expand All @@ -31,7 +34,9 @@ public void increase(ProxyContext context) {
return;
}

Id id = new Id(entity.id(), deployment.getName(), entity.user());
entity.setTokeUsage(context.getCurrentSpanId(), usage);

Id id = new Id(entity.getId(), deployment.getName(), entity.isUser());
RateLimit rate = rates.computeIfAbsent(id, k -> new RateLimit());

long timestamp = System.currentTimeMillis();
Expand All @@ -41,12 +46,11 @@ public void increase(ProxyContext context) {
public RateLimitResult limit(ProxyContext context) {
Entity entity = getEntityFromTracingContext(context);
if (entity == null) {
Span span = Span.current();
log.warn("Entity is not found by traceId={}", span.getSpanContext().getTraceId());
log.warn("Entity is not found by traceId={}", context.getTraceId());
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied");
}
Limit limit;
if (entity.user()) {
if (entity.isUser()) {
// don't support user limits yet
return RateLimitResult.SUCCESS;
} else {
Expand All @@ -64,7 +68,7 @@ public RateLimitResult limit(ProxyContext context) {
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied");
}

Id id = new Id(entity.id(), deployment.getName(), entity.user());
Id id = new Id(entity.getId(), deployment.getName(), entity.isUser());
RateLimit rate = rates.get(id);

if (rate == null) {
Expand All @@ -81,22 +85,36 @@ public RateLimitResult limit(ProxyContext context) {
public boolean register(ProxyContext context) {
String traceId = context.getTraceId();
Entity entity = traceIdToEntity.get(traceId);
if (entity != null) {
boolean result = entity != null;
if (result) {
entity.link(context.getParentSpanId(), context.getCurrentSpanId());
// update context with the original requester
if (entity.user()) {
context.setUserHash(entity.name());
if (entity.isUser()) {
context.setUserHash(entity.getName());
} else {
context.setOriginalProject(entity.name());
context.setOriginalProject(entity.getName());
}
} else {
if (context.getKey() != null) {
Key key = context.getKey();
traceIdToEntity.put(traceId, new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false));
entity = new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false);
traceIdToEntity.put(traceId, entity);
} else {
traceIdToEntity.put(traceId, new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true));
entity = new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true);
traceIdToEntity.put(traceId, entity);
}
}
return entity != null;
entity.addCall(context.getCurrentSpanId());
return result;
}

public void calculateTokenUsage(ProxyContext context) {
Entity entity = getEntityFromTracingContext(context);
if (entity == null) {
log.warn("Entity is not found by traceId={}", context.getTraceId());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add prefix to locate the problem more easily: "Failed to calculate token usage"

return;
}
context.setTokenUsage(entity.calculate(context.getCurrentSpanId()));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better to do it only if it is missing? Some applications/models can compute it more precisely.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't rely on applications/assistants any longer to compute statistics

}

public void unregister(ProxyContext context) {
Expand Down Expand Up @@ -129,11 +147,72 @@ public String toString() {
}
}

private record Entity(String id, List<String> roles, String name, boolean user) {
@Data
private static class Entity {

private final String id;
private final List<String> roles;
private final String name;
private final boolean user;
private final Map<String, CallInfo> spanIdToCallInfo = new HashMap<>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move the logic related to token tracking to a separate class. e.g. TokenTracker


/**
 * Creates an entity from its identifying attributes.
 *
 * @param id    identifier of the entity
 * @param roles roles assigned to the entity
 * @param name  name of the entity
 * @param user  flag stored as the entity's {@code user} field
 */
public Entity(String id, List<String> roles, String name, boolean user) {
    this.user = user;
    this.name = name;
    this.roles = roles;
    this.id = id;
}

/**
 * Stores a fresh {@link CallInfo} under the given span id, replacing any entry already
 * mapped to that id.
 */
public synchronized void addCall(String spanId) {
    CallInfo freshCall = new CallInfo();
    spanIdToCallInfo.put(spanId, freshCall);
}

public synchronized void link(String parentSpanId, String childSpanId) {
CallInfo callInfo = spanIdToCallInfo.get(parentSpanId);
if (callInfo == null) {
log.warn("Parent span is not found by id: {}", parentSpanId);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please improve logging in this class.

return;
}
callInfo.childSpanIds.add(childSpanId);
}

public synchronized void setTokeUsage(String spanId, TokenUsage tokenUsage) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please correct the typo: setTokenUsage

CallInfo callInfo = spanIdToCallInfo.get(spanId);
if (callInfo == null) {
log.warn("Span is not found by id: {}", spanId);
return;
}
callInfo.setTokenUsage(tokenUsage);
}

/**
 * Adds the usage of the direct child spans to the usage recorded for the given span.
 *
 * <p>The children's usage is added into the span's own {@link TokenUsage} instance, which is
 * then returned. Child ids without a registered call are logged and skipped.
 *
 * @param spanId id of the span whose usage is requested
 * @return the accumulated usage, or {@code null} when no call is registered for the span
 */
public synchronized TokenUsage calculate(String spanId) {
    CallInfo current = spanIdToCallInfo.get(spanId);
    if (current == null) {
        log.warn("Span is not found by id: {}", spanId);
        return null;
    }
    TokenUsage total = current.tokenUsage;
    for (String childId : current.childSpanIds) {
        CallInfo child = spanIdToCallInfo.get(childId);
        if (child != null) {
            total.increase(child.tokenUsage);
        } else {
            log.warn("Child span is not found by id: {}", childId);
        }
    }
    return total;
}

/**
 * Renders the id, roles and user flag of this entity for diagnostics.
 */
@Override
public String toString() {
    return "Entity: %s, resource: %s, user: %b".formatted(id, roles, user);
}
}

/**
 * Per-span bookkeeping: the token usage recorded for one span and the ids of the spans
 * linked to it as children.
 */
@Data
private static class CallInfo {
// Starts as a freshly constructed usage object; can be replaced through the Lombok-generated setter.
TokenUsage tokenUsage = new TokenUsage();
// Ids of the child spans linked to this call.
List<String> childSpanIds = new ArrayList<>();
}

}
16 changes: 14 additions & 2 deletions src/main/java/com/epam/aidial/core/log/GfLogStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.epam.aidial.core.Proxy;
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.deltix.gflog.api.Log;
import com.epam.deltix.gflog.api.LogEntry;
Expand Down Expand Up @@ -94,8 +95,19 @@ private void append(ProxyContext context, LogEntry entry) {

append(entry, "\",\"body\":\"", false);
append(entry, context.getResponseBody());

append(entry, "\"}}", false);
TokenUsage tokenUsage = context.getTokenUsage();
if (tokenUsage != null) {
append(entry, "\"},\"tokenUsage\":{", false);
append(entry, "\"completion_tokens\":", false);
append(entry, Long.toString(tokenUsage.getCompletionTokens()), true);
append(entry, ",\"prompt_tokens\":", false);
append(entry, Long.toString(tokenUsage.getPromptTokens()), true);
append(entry, ",\"total_tokens\":", false);
append(entry, Long.toString(tokenUsage.getTotalTokens()), true);
append(entry, "}}", false);
} else {
append(entry, "\"}}", false);
}
}


Expand Down
9 changes: 9 additions & 0 deletions src/main/java/com/epam/aidial/core/token/TokenUsage.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ public class TokenUsage {
private long promptTokens;
private long totalTokens;

/**
 * Adds the counters of the given usage to this instance; a {@code null} argument is ignored.
 *
 * @param usage usage whose completion, prompt and total counters are added, may be {@code null}
 */
public void increase(TokenUsage usage) {
    if (usage != null) {
        promptTokens += usage.promptTokens;
        completionTokens += usage.completionTokens;
        totalTokens += usage.totalTokens;
    }
}

@Override
public String toString() {
return "completion=" + completionTokens
Expand Down
Loading