-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Calculate token usage statistics using call stack #117 #118
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,14 @@ | |
import com.epam.aidial.core.config.Role; | ||
import com.epam.aidial.core.token.TokenUsage; | ||
import com.epam.aidial.core.util.HttpStatus; | ||
import io.opentelemetry.api.trace.Span; | ||
import lombok.Data; | ||
import lombok.extern.slf4j.Slf4j; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
@Slf4j | ||
|
@@ -21,7 +24,7 @@ public class RateLimiter { | |
|
||
public void increase(ProxyContext context) { | ||
Entity entity = getEntityFromTracingContext(context); | ||
if (entity == null || entity.user()) { | ||
if (entity == null || entity.isUser()) { | ||
return; | ||
} | ||
Deployment deployment = context.getDeployment(); | ||
|
@@ -31,7 +34,9 @@ public void increase(ProxyContext context) { | |
return; | ||
} | ||
|
||
Id id = new Id(entity.id(), deployment.getName(), entity.user()); | ||
entity.setTokeUsage(context.getCurrentSpanId(), usage); | ||
|
||
Id id = new Id(entity.getId(), deployment.getName(), entity.isUser()); | ||
RateLimit rate = rates.computeIfAbsent(id, k -> new RateLimit()); | ||
|
||
long timestamp = System.currentTimeMillis(); | ||
|
@@ -41,12 +46,11 @@ public void increase(ProxyContext context) { | |
public RateLimitResult limit(ProxyContext context) { | ||
Entity entity = getEntityFromTracingContext(context); | ||
if (entity == null) { | ||
Span span = Span.current(); | ||
log.warn("Entity is not found by traceId={}", span.getSpanContext().getTraceId()); | ||
log.warn("Entity is not found by traceId={}", context.getTraceId()); | ||
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"); | ||
} | ||
Limit limit; | ||
if (entity.user()) { | ||
if (entity.isUser()) { | ||
// don't support user limits yet | ||
return RateLimitResult.SUCCESS; | ||
} else { | ||
|
@@ -64,7 +68,7 @@ public RateLimitResult limit(ProxyContext context) { | |
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"); | ||
} | ||
|
||
Id id = new Id(entity.id(), deployment.getName(), entity.user()); | ||
Id id = new Id(entity.getId(), deployment.getName(), entity.isUser()); | ||
RateLimit rate = rates.get(id); | ||
|
||
if (rate == null) { | ||
|
@@ -81,22 +85,36 @@ public RateLimitResult limit(ProxyContext context) { | |
public boolean register(ProxyContext context) { | ||
String traceId = context.getTraceId(); | ||
Entity entity = traceIdToEntity.get(traceId); | ||
if (entity != null) { | ||
boolean result = entity != null; | ||
if (result) { | ||
entity.link(context.getParentSpanId(), context.getCurrentSpanId()); | ||
// update context with the original requester | ||
if (entity.user()) { | ||
context.setUserHash(entity.name()); | ||
if (entity.isUser()) { | ||
context.setUserHash(entity.getName()); | ||
} else { | ||
context.setOriginalProject(entity.name()); | ||
context.setOriginalProject(entity.getName()); | ||
} | ||
} else { | ||
if (context.getKey() != null) { | ||
Key key = context.getKey(); | ||
traceIdToEntity.put(traceId, new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false)); | ||
entity = new Entity(key.getKey(), Collections.singletonList(key.getRole()), key.getProject(), false); | ||
traceIdToEntity.put(traceId, entity); | ||
} else { | ||
traceIdToEntity.put(traceId, new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true)); | ||
entity = new Entity(context.getUserSub(), context.getUserRoles(), context.getUserHash(), true); | ||
traceIdToEntity.put(traceId, entity); | ||
} | ||
} | ||
return entity != null; | ||
entity.addCall(context.getCurrentSpanId()); | ||
return result; | ||
} | ||
|
||
public void calculateTokenUsage(ProxyContext context) { | ||
Entity entity = getEntityFromTracingContext(context); | ||
if (entity == null) { | ||
log.warn("Entity is not found by traceId={}", context.getTraceId()); | ||
return; | ||
} | ||
context.setTokenUsage(entity.calculate(context.getCurrentSpanId())); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it better to do it only if it is missing? Some applications/models can compute it more precisely. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can rely on applications/assistants any longer to compute statistics |
||
} | ||
|
||
public void unregister(ProxyContext context) { | ||
|
@@ -129,11 +147,72 @@ public String toString() { | |
} | ||
} | ||
|
||
private record Entity(String id, List<String> roles, String name, boolean user) { | ||
@Data | ||
private static class Entity { | ||
|
||
private final String id; | ||
private final List<String> roles; | ||
private final String name; | ||
private final boolean user; | ||
private final Map<String, CallInfo> spanIdToCallInfo = new HashMap<>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move the logic related to token tracking to a separate class. e.g. TokenTracker |
||
|
||
public Entity(String id, List<String> roles, String name, boolean user) { | ||
this.id = id; | ||
this.roles = roles; | ||
this.name = name; | ||
this.user = user; | ||
} | ||
|
||
public synchronized void addCall(String spanId) { | ||
spanIdToCallInfo.put(spanId, new CallInfo()); | ||
} | ||
|
||
public synchronized void link(String parentSpanId, String childSpanId) { | ||
CallInfo callInfo = spanIdToCallInfo.get(parentSpanId); | ||
if (callInfo == null) { | ||
log.warn("Parent span is not found by id: {}", parentSpanId); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please improve logging in this class. |
||
return; | ||
} | ||
callInfo.childSpanIds.add(childSpanId); | ||
} | ||
|
||
public synchronized void setTokeUsage(String spanId, TokenUsage tokenUsage) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please correct type: setTokenUsage |
||
CallInfo callInfo = spanIdToCallInfo.get(spanId); | ||
if (callInfo == null) { | ||
log.warn("Span is not found by id: {}", spanId); | ||
return; | ||
} | ||
callInfo.setTokenUsage(tokenUsage); | ||
} | ||
|
||
public synchronized TokenUsage calculate(String spanId) { | ||
CallInfo callInfo = spanIdToCallInfo.get(spanId); | ||
if (callInfo == null) { | ||
log.warn("Span is not found by id: {}", spanId); | ||
return null; | ||
} | ||
TokenUsage tokenUsage = callInfo.tokenUsage; | ||
for (String childSpanId : callInfo.childSpanIds) { | ||
CallInfo childCall = spanIdToCallInfo.get(childSpanId); | ||
if (childCall == null) { | ||
log.warn("Child span is not found by id: {}", childSpanId); | ||
continue; | ||
} | ||
tokenUsage.increase(childCall.tokenUsage); | ||
} | ||
return tokenUsage; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return String.format("Entity: %s, resource: %s, user: %b", id, roles, user); | ||
} | ||
} | ||
|
||
@Data | ||
private static class CallInfo { | ||
TokenUsage tokenUsage = new TokenUsage(); | ||
List<String> childSpanIds = new ArrayList<>(); | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add prefix to locate the problem more easily: "Failed to calculate token usage"