Skip to content

Commit

Permalink
fix: API key is not sent to Base Assistant #88 (#89)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Dec 15, 2023
1 parent 873f0a7 commit 4145de2
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 38 deletions.
25 changes: 22 additions & 3 deletions src/main/java/com/epam/aidial/core/config/FileConfigStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.epam.aidial.core.util.ProxyUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.annotations.VisibleForTesting;
import io.vertx.core.Vertx;
import io.vertx.core.json.JsonObject;
import lombok.SneakyThrows;
Expand All @@ -11,8 +12,10 @@
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

import static com.epam.aidial.core.config.Config.ASSISTANT;
import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;


Expand All @@ -21,6 +24,7 @@ public final class FileConfigStore implements ConfigStore {

private final String[] paths;
private volatile Config config;
private final Map<String, String> deploymentKeys = new HashMap<>();

public FileConfigStore(Vertx vertx, JsonObject settings) {
this.paths = settings.getJsonArray("files")
Expand Down Expand Up @@ -61,6 +65,12 @@ private void load(boolean fail) {
}

Assistants assistants = config.getAssistant();
// base assistant
if (assistants.getEndpoint() != null) {
Assistant baseAssistant = new Assistant();
baseAssistant.setName(ASSISTANT);
associateDeploymentWithApiKey(config, baseAssistant);
}
for (Map.Entry<String, Assistant> entry : assistants.getAssistants().entrySet()) {
String name = entry.getKey();
Assistant assistant = entry.getValue();
Expand Down Expand Up @@ -103,9 +113,17 @@ private void load(boolean fail) {
}
}

private void associateDeploymentWithApiKey(Config config, Deployment deployment) {
String apiKey = deployment.getApiKey() == null ? generateKey() : deployment.getApiKey();
while (config.getKeys().containsKey(apiKey)) {
@VisibleForTesting
void associateDeploymentWithApiKey(Config config, Deployment deployment) {
String apiKey = deployment.getApiKey();
String deploymentName = deployment.getName();
if (apiKey == null) {
apiKey = deploymentKeys.computeIfAbsent(deploymentName, k -> generateKey());
} else {
deploymentKeys.put(deploymentName, apiKey);
}
Map<String, Key> keys = config.getKeys();
while (keys.containsKey(apiKey) && !deploymentName.equals(keys.get(apiKey).getProject())) {
log.warn("duplicate API key is found for deployment {}. Trying to generate a new one", deployment.getName());
apiKey = generateKey();
}
Expand All @@ -114,6 +132,7 @@ private void associateDeploymentWithApiKey(Config config, Deployment deployment)
key.setKey(apiKey);
key.setProject(deployment.getName());
config.getKeys().put(apiKey, key);
deploymentKeys.put(deployment.getName(), apiKey);
}

private Config loadConfig() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBufInputStream;
import io.vertx.core.Future;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
Expand All @@ -40,7 +42,6 @@

import java.io.InputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -150,7 +151,8 @@ private void handleRequestBody(Buffer requestBody) {
/**
* Called when proxy connected to the origin.
*/
private void handleProxyRequest(HttpClientRequest proxyRequest) {
@VisibleForTesting
void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin. Key: {}. Deployment: {}. Address: {}", context.getProject(),
context.getDeployment().getName(), proxyRequest.connection().remoteAddress());

Expand All @@ -159,10 +161,10 @@ private void handleProxyRequest(HttpClientRequest proxyRequest) {
context.setProxyConnectTimestamp(System.currentTimeMillis());

Deployment deployment = context.getDeployment();
Set<CharSequence> excludeHeaders = new HashSet<>();
excludeHeaders.add(Proxy.HEADER_API_KEY);
MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap();
excludeHeaders.add(Proxy.HEADER_API_KEY, "whatever");
if (!deployment.isForwardAuthToken()) {
excludeHeaders.add(HttpHeaders.AUTHORIZATION);
excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever");
}

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/epam/aidial/core/util/ProxyUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class ProxyUtil {
.add(HttpHeaders.ACCEPT_ENCODING, "whatever");

public static void copyHeaders(MultiMap from, MultiMap to) {
copyHeaders(from, to, Collections.emptySet());
copyHeaders(from, to, MultiMap.caseInsensitiveMultiMap());
}

public static void copyHeaders(MultiMap from, MultiMap to, Set<CharSequence> excludeHeaders) {
public static void copyHeaders(MultiMap from, MultiMap to, MultiMap excludeHeaders) {
for (Map.Entry<String, String> entry : from.entries()) {
String key = entry.getKey();
String value = entry.getValue();
Expand Down
152 changes: 152 additions & 0 deletions src/test/java/com/epam/aidial/core/config/FileConfigStoreTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package com.epam.aidial.core.config;

import com.epam.aidial.core.AiDial;
import io.vertx.core.Vertx;
import io.vertx.core.json.JsonObject;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.mock;

public class FileConfigStoreTest {

private FileConfigStore store;

private static JsonObject SETTINGS;

@BeforeAll
public static void beforeAll() throws IOException {
String file = "aidial.settings.json";
try (InputStream stream = AiDial.class.getClassLoader().getResourceAsStream(file)) {
Objects.requireNonNull(stream, "Default resource file with settings is not found");
String json = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
SETTINGS = new JsonObject(json);
}
}

@BeforeEach
public void beforeEach() {
store = new FileConfigStore(mock(Vertx.class), SETTINGS.getJsonObject("config"));
}

@Test
public void testAssociateDeploymentWithApiKey_DuplicateKey() {
Config config = new Config();
Key k1 = new Key();
k1.setProject("some");
config.getKeys().put("k1", k1);
Application app = new Application();
app.setName("app");
app.setApiKey("k1");
store.associateDeploymentWithApiKey(config, app);
assertNotNull(app.getApiKey());
assertNotEquals("k1", app.getApiKey());
assertEquals(2, config.getKeys().size());
assertEquals(app.getName(), config.getKeys().get(app.getApiKey()).getProject());
assertNotNull(config.getKeys().get(app.getApiKey()).getKey());
assertEquals(k1.getProject(), config.getKeys().get("k1").getProject());
}

@Test
public void testAssociateDeploymentWithApiKey_MissedKey() {
Config config = new Config();
Key k1 = new Key();
k1.setProject("some");
config.getKeys().put("k1", k1);
Application app = new Application();
app.setName("app");
store.associateDeploymentWithApiKey(config, app);
assertNotNull(app.getApiKey());
assertNotEquals("k1", app.getApiKey());
assertEquals(2, config.getKeys().size());
assertEquals(app.getName(), config.getKeys().get(app.getApiKey()).getProject());
assertNotNull(config.getKeys().get(app.getApiKey()).getKey());
assertEquals(k1.getProject(), config.getKeys().get("k1").getProject());
}

@Test
public void testAssociateDeploymentWithApiKey_DifferentKey() {
Config config = new Config();
Key k1 = new Key();
k1.setProject("some");
config.getKeys().put("k1", k1);
Application app = new Application();
app.setName("app");
app.setApiKey("k2");
store.associateDeploymentWithApiKey(config, app);
assertNotNull(app.getApiKey());
assertNotEquals("k1", app.getApiKey());
assertEquals(2, config.getKeys().size());
assertEquals(app.getName(), config.getKeys().get(app.getApiKey()).getProject());
assertNotNull(config.getKeys().get(app.getApiKey()).getKey());
assertEquals(k1.getProject(), config.getKeys().get("k1").getProject());
}

@Test
public void testAssociateDeploymentWithApiKey_Reload() {
String apiKey = null;
for (int i = 0; i < 3; i++) {
Config config = new Config();
Key k1 = new Key();
k1.setProject("some");
config.getKeys().put("k1", k1);
Application app = new Application();
app.setName("app");

store.associateDeploymentWithApiKey(config, app);

if (i == 0) {
apiKey = app.getApiKey();
} else {
assertEquals(apiKey, app.getApiKey());
}
assertNotNull(app.getApiKey());
assertNotEquals("k1", app.getApiKey());
assertEquals(2, config.getKeys().size());
assertEquals(app.getName(), config.getKeys().get(app.getApiKey()).getProject());
assertNotNull(config.getKeys().get(app.getApiKey()).getKey());
assertEquals(k1.getProject(), config.getKeys().get("k1").getProject());
}
}

@Test
public void testAssociateDeploymentWithApiKey_ReloadKeyChanged() {
String apiKey = null;
for (int i = 0; i < 5; i++) {
Config config = new Config();
Key k1 = new Key();
k1.setProject("some");
config.getKeys().put("k1", k1);
Application app = new Application();
app.setName("app");
if (i == 2) {
app.setApiKey("k2");
}
if (i == 4) {
app.setApiKey(null);
}
store.associateDeploymentWithApiKey(config, app);
assertNotNull(app.getApiKey());
assertNotEquals("k1", app.getApiKey());
assertEquals(2, config.getKeys().size());
assertEquals(app.getName(), config.getKeys().get(app.getApiKey()).getProject());
assertNotNull(config.getKeys().get(app.getApiKey()).getKey());
assertEquals(k1.getProject(), config.getKeys().get("k1").getProject());
switch (i) {
case 1 -> assertEquals(apiKey, app.getApiKey());
case 2, 3 -> assertEquals("k2", app.getApiKey());
case 4 -> assertNotEquals(apiKey, app.getApiKey());
default -> apiKey = app.getApiKey();
}
}
}
}
Loading

0 comments on commit 4145de2

Please sign in to comment.