diff --git a/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java b/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java index 3e6ab0c8e..f3fb1b9d3 100644 --- a/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java +++ b/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java @@ -3,12 +3,20 @@ import com.epam.aidial.core.ProxyContext; import com.epam.aidial.core.config.Deployment; import com.epam.aidial.core.config.Model; +import com.epam.aidial.core.config.ModelType; import com.epam.aidial.core.config.Pricing; import com.epam.aidial.core.token.TokenUsage; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.netty.buffer.ByteBufInputStream; +import io.vertx.core.buffer.Buffer; import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; +import java.io.InputStream; import java.math.BigDecimal; +import java.util.Scanner; @Slf4j @UtilityClass @@ -25,7 +33,12 @@ public static BigDecimal calculate(ProxyContext context) { return null; } - return calculate(context.getTokenUsage(), pricing.getPrompt(), pricing.getCompletion()); + return switch (pricing.getUnit()) { + case "token" -> calculate(context.getTokenUsage(), pricing.getPrompt(), pricing.getCompletion()); + case "char_without_whitespace" -> + calculate(model.getType(), context.getRequestBody(), context.getResponseBody(), pricing.getPrompt(), pricing.getCompletion()); + default -> null; + }; } private static BigDecimal calculate(TokenUsage tokenUsage, String promptRate, String completionRate) { @@ -47,4 +60,112 @@ private static BigDecimal calculate(TokenUsage tokenUsage, String promptRate, St return cost; } + private static BigDecimal calculate(ModelType modelType, Buffer requestBody, Buffer responseBody, String promptRate, String completionRate) { + RequestLengthResult requestLengthResult = getRequestContentLength(modelType, requestBody); + int responseLength = getResponseContentLength(modelType, responseBody, requestLengthResult.stream()); + BigDecimal cost = null; + if (promptRate != null) { + cost = new BigDecimal(requestLengthResult.length()).multiply(new BigDecimal(promptRate)); + } + if (completionRate != null) { + BigDecimal completionCost = new BigDecimal(responseLength).multiply(new BigDecimal(completionRate)); + if (cost == null) { + cost = completionCost; + } else { + cost = cost.add(completionCost); + } + } + return cost; + } + + private static int getResponseContentLength(ModelType modelType, Buffer responseBody, boolean isStreamingResponse) { + if (modelType == ModelType.EMBEDDING) { + return 0; + } + if (isStreamingResponse) { + try (Scanner scanner = new Scanner(new ByteBufInputStream(responseBody.getByteBuf()))) { + scanner.useDelimiter("\n*data: *"); + int len = 0; + while (scanner.hasNext()) { + String chunk = scanner.next(); + if (chunk.startsWith("[DONE]")) { + break; + } + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(chunk); + ArrayNode choices = (ArrayNode) tree.get("choices"); + if (choices == null) { + // skip error message + continue; + } + JsonNode contentNode = choices.get(0).get("delta").get("content"); + if (contentNode != null) { + len += getLengthWithoutWhitespace(contentNode.textValue()); + } + } + return len; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } else { + try (InputStream stream = new ByteBufInputStream(responseBody.getByteBuf())) { + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream); + ArrayNode choices = (ArrayNode) tree.get("choices"); + if (choices == null) { + // skip error message + return 0; + } + JsonNode contentNode = choices.get(0).get("message").get("content"); + return getLengthWithoutWhitespace(contentNode.textValue()); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + } + + private static RequestLengthResult getRequestContentLength(ModelType modelType, Buffer requestBody) { + try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) { + int len; + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream); + if (modelType == ModelType.CHAT) { + ArrayNode messages = (ArrayNode) tree.get("messages"); + len = 0; + for (int i = 0; i < messages.size(); i++) { + JsonNode message = messages.get(i); + len += getLengthWithoutWhitespace(message.get("content").textValue()); + } + return new RequestLengthResult(len, tree.get("stream").asBoolean(false)); + } else { + JsonNode input = tree.get("input"); + if (input instanceof ArrayNode array) { + len = 0; + for (int i = 0; i < array.size(); i++) { + len += getLengthWithoutWhitespace(array.get(i).textValue()); + } + } else { + len = getLengthWithoutWhitespace(input.textValue()); + } + } + return new RequestLengthResult(len, false); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + private static int getLengthWithoutWhitespace(String s) { + if (s == null) { + return 0; + } + int len = 0; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) != ' ') { + len++; + } + } + return len; + } + + private record RequestLengthResult(int length, boolean stream) { + + } + } diff --git a/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java b/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java index 0f30a5b07..8ddf8bc6b 100644 --- a/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java +++ b/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java @@ -2,8 +2,10 @@ import com.epam.aidial.core.ProxyContext; import com.epam.aidial.core.config.Model; +import com.epam.aidial.core.config.ModelType; import com.epam.aidial.core.config.Pricing; import com.epam.aidial.core.token.TokenUsage; +import io.vertx.core.buffer.Buffer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -61,4 +63,250 @@ public void testCalculate_TokenCost() { assertEquals(new BigDecimal("6.0"), ModelCostCalculator.calculate(context)); } + @Test + public void testCalculate_LengthCost_Chat_StreamIsFalse_Success() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + { + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "A file is a named collection." + } + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 343, + "total_tokens": 347 + }, + "id": "fd3be95a-c208-4dca-90cf-67e5082a4e5b", + "created": 1705319789, + "object": "chat.completion" + } + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": false + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("13.0"), ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_Chat_StreamIsFalse_Error() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + {"error": { "message": "message", "type": "type", "param": "param", "code": "code" } } + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": false + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("1.0"), ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_Chat_StreamIsTrue_Success() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"usage":null} + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"this"}}],"usage":null} + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":" is "}}],"usage":null} + + + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"a text"}}],"usage":null} + + data: [DONE] + + + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": true + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("6.5"), ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_Chat_StreamIsTrue_Error() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"usage":null} + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"this"}}],"usage":null} + + data: {"error": { "message": "message", "type": "type", "param": "param", "code": "code" } } + + + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"a text"}}],"usage":null} + + data: [DONE] + + + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": true + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("5.5"), ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_EmbeddingInputIsArray() { + Model model = new Model(); + model.setType(ModelType.EMBEDDING); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + {} + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "input": ["text", "123"] + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("0.7"), ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_EmbeddingInputIsString() { + Model model = new Model(); + model.setType(ModelType.EMBEDDING); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + {} + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "input": "text" + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(new BigDecimal("0.4"), ModelCostCalculator.calculate(context)); + } }