Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor base request functions #579

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,16 @@ void handleRequestBody(Buffer requestBody) {

try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);
Throwable error = ProxyUtil.processChain(tree, enhancementFunctions);
if (error != null) {
finalizeRequest();
return;
if (ProxyUtil.processChain(tree, enhancementFunctions)) {
context.setRequestBody(Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree)));
}
} catch (IOException e) {
respond(HttpStatus.BAD_REQUEST);
log.warn("Can't parse JSON request body. Trace: {}. Span: {}. Error:",
} catch (Throwable e) {
if (e instanceof HttpException httpException) {
respond(httpException.getStatus(), httpException.getMessage());
} else {
respond(HttpStatus.BAD_REQUEST);
}
log.warn("Can't process JSON request body. Trace: {}. Span: {}. Error:",
context.getTraceId(), context.getSpanId(), e);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.epam.aidial.core.server.function.CollectResponseAttachmentsFn;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.buffer.ByteBufInputStream;
Expand Down Expand Up @@ -69,14 +70,16 @@ private void handleRequestBody(Buffer requestBody) {
context.setRequestBodyTimestamp(System.currentTimeMillis());
try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);
Throwable error = ProxyUtil.processChain(tree, enhancementFunctions);
if (error != null) {
finalizeRequest();
return;
if (ProxyUtil.processChain(tree, enhancementFunctions)) {
context.setRequestBody(Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree)));
}
} catch (IOException e) {
respond(HttpStatus.BAD_REQUEST);
log.warn("Can't parse JSON request body. Trace: {}. Span: {}. Error:",
} catch (Throwable e) {
if (e instanceof HttpException httpException) {
respond(httpException.getStatus(), httpException.getMessage());
} else {
respond(HttpStatus.BAD_REQUEST);
}
log.warn("Can't process JSON request body. Trace: {}. Span: {}. Error:",
context.getTraceId(), context.getSpanId(), e);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;

public abstract class BaseRequestFunction<T> extends BaseFunction<T, Throwable> {
public abstract class BaseRequestFunction<T> extends BaseFunction<T, Boolean> {


public BaseRequestFunction(Proxy proxy, ProxyContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,12 @@ public CollectRequestAttachmentsFn(Proxy proxy, ProxyContext context) {
}

@Override
public Throwable apply(ObjectNode tree) {
try {
ProxyUtil.collectAttachedFilesFromRequest(tree, this::processAttachedFile);
// assign api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().assignPerRequestApiKey(destApiKeyData);
return null;
} catch (HttpException e) {
context.respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
}
public Boolean apply(ObjectNode tree) {
ProxyUtil.collectAttachedFilesFromRequest(tree, this::processAttachedFile);
// assign api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().assignPerRequestApiKey(destApiKeyData);
return false;
}

private void processAttachedFile(String url) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ public CollectRequestDataFn(Proxy proxy, ProxyContext context) {
}

@Override
public Throwable apply(ObjectNode tree) {
public Boolean apply(ObjectNode tree) {
JsonNode stream = tree.get("stream");
boolean result = stream != null && stream.asBoolean(false);
context.setStreamingRequest(result);
return null;
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.buffer.Buffer;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
Expand All @@ -20,21 +18,7 @@ public ApplyDefaultDeploymentSettingsFn(Proxy proxy, ProxyContext context) {
}

@Override
public Throwable apply(ObjectNode tree) {
try {
if (applyDefaults(context, tree)) {
context.setRequestBody(Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree)));
}
return null;
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't apply default parameters to deployment {}. Trace: {}. Span: {}. Error: {}",
context.getDeployment().getName(), context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
}
}

private static boolean applyDefaults(ProxyContext context, ObjectNode tree) {
public Boolean apply(ObjectNode tree) {
Deployment deployment = context.getDeployment();
boolean applied = false;
for (Map.Entry<String, Object> e : deployment.getDefaults().entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.controller.DeploymentController;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.buffer.Buffer;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
Expand All @@ -30,30 +28,16 @@ public EnhanceAssistantRequestFn(Proxy proxy, ProxyContext context) {
}

@Override
public Throwable apply(ObjectNode tree) {
public Boolean apply(ObjectNode tree) {
Deployment deployment = context.getDeployment();
if (deployment instanceof Assistant) {
try {
Map.Entry<Buffer, Map<String, String>> enhancedRequest = enhanceAssistantRequest(context, tree);
context.setRequestBody(enhancedRequest.getKey());
context.setRequestHeaders(enhancedRequest.getValue());
} catch (HttpException e) {
context.respond(e.getStatus(), e.getMessage());
log.warn("Can't enhance assistant request. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't enhance assistant request. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
}
enhanceAssistantRequest(context, tree);
return true;
}
return null;
return false;
}

private static Map.Entry<Buffer, Map<String, String>> enhanceAssistantRequest(ProxyContext context, ObjectNode tree)
throws Exception {
private static void enhanceAssistantRequest(ProxyContext context, ObjectNode tree) {
Config config = context.getConfig();
Assistant assistant = (Assistant) context.getDeployment();

Expand Down Expand Up @@ -108,8 +92,7 @@ private static Map.Entry<Buffer, Map<String, String>> enhanceAssistantRequest(Pr
throw new HttpException(HttpStatus.FORBIDDEN, "Forbidden model: " + name);
}

Buffer updatedBody = Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
return Map.entry(updatedBody, headers);
context.setRequestHeaders(headers);
}

private static void deletePrompt(ArrayNode messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.buffer.Buffer;
import lombok.extern.slf4j.Slf4j;

@Slf4j
Expand All @@ -18,33 +15,25 @@ public EnhanceModelRequestFn(Proxy proxy, ProxyContext context) {
}

@Override
public Throwable apply(ObjectNode tree) {
public Boolean apply(ObjectNode tree) {
Deployment deployment = context.getDeployment();
if (deployment instanceof Model) {
try {
context.setRequestBody(enhanceModelRequest(context, tree));
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't enhance model request. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return e;
}
return enhanceModelRequest(context, tree);
}
return null;
return false;
}

private static Buffer enhanceModelRequest(ProxyContext context, ObjectNode tree) throws Exception {
private static boolean enhanceModelRequest(ProxyContext context, ObjectNode tree) {
Model model = (Model) context.getDeployment();
String overrideName = model.getOverrideName();
Buffer requestBody = context.getRequestBody();

if (overrideName == null) {
return requestBody;
return false;
}

tree.remove("model");
tree.put("model", overrideName);

return Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ public static String convertToString(Object data) {
}
}

public static <T> Throwable processChain(T item, List<BaseRequestFunction<T>> chain) {
public static <T> boolean processChain(T item, List<BaseRequestFunction<T>> chain) {
boolean result = false;
for (BaseRequestFunction<T> fn : chain) {
Throwable error = fn.apply(item);
if (error != null) {
return error;
if (fn.apply(item)) {
result = true;
}
}
return null;
return result;
}

public static EtagHeader etag(HttpServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public void test_01() throws JsonProcessingException {
doCallRealMethod().when(context).setStreamingRequest(anyBoolean());
when(context.isStreamingRequest()).thenCallRealMethod();

Throwable error = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{\"stream\": true}"));
boolean result = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{\"stream\": true}"));

assertNull(error);
assertFalse(result);
assertTrue(context.isStreamingRequest());
}

Expand All @@ -46,9 +46,9 @@ public void test_02() throws JsonProcessingException {
doCallRealMethod().when(context).setStreamingRequest(anyBoolean());
when(context.isStreamingRequest()).thenCallRealMethod();

Throwable error = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{\"stream\": false}"));
boolean result = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{\"stream\": false}"));

assertNull(error);
assertFalse(result);
assertFalse(context.isStreamingRequest());
}

Expand All @@ -57,9 +57,9 @@ public void test_03() throws JsonProcessingException {
doCallRealMethod().when(context).setStreamingRequest(anyBoolean());
when(context.isStreamingRequest()).thenCallRealMethod();

Throwable error = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{}"));
boolean result = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{}"));

assertNull(error);
assertFalse(result);
assertFalse(context.isStreamingRequest());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.buffer.Buffer;
import org.junit.jupiter.api.Test;
Expand All @@ -18,6 +19,7 @@
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -42,12 +44,8 @@ public void test() throws JsonProcessingException {
Map<String, Object> defaults = Map.of("key1", true, "key2", 123, "key3", 0.45, "key4", "str");
model.setDefaults(defaults);
when(context.getDeployment()).thenReturn(model);
Mockito.doCallRealMethod().when(context).setRequestBody(any(Buffer.class));
when(context.getRequestBody()).thenCallRealMethod();
Throwable error = fn.apply((ObjectNode) ProxyUtil.MAPPER.readTree("{}"));
assertNull(error);
String json = context.getRequestBody().toString(StandardCharsets.UTF_8);
ObjectNode result = (ObjectNode) ProxyUtil.MAPPER.readTree(json);
JsonNode result = ProxyUtil.MAPPER.readTree("{}");
assertTrue(fn.apply((ObjectNode) result));
assertNotNull(result);
assertEquals(123, result.get("key2").asInt());
assertEquals(0.45, result.get("key3").asDouble());
Expand Down
Loading