Skip to content

Commit

Permalink
Merge branch 'development' into fix-file-metadata-url-encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Maksim_Hadalau authored and Maksim_Hadalau committed Dec 19, 2023
2 parents b8de7b6 + 82f3d92 commit 6aafc45
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/main/java/com/epam/aidial/core/config/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ public class Model extends Deployment {
private TokenLimits limits;
private Pricing pricing;
private List<Upstream> upstreams = List.of();
// If set, this name replaces the model name in the request body sent to the model adapter.
private String overrideName;
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
deployment = null;
}

if (deployment == null || (!isBaseAssistant(deployment) && !DeploymentController.hasAccess(context, deployment))) {
if (deployment == null) {
log.error("Deployment {} is not found", deploymentId);
return context.respond(HttpStatus.NOT_FOUND, "Deployment is not found");
}

if (!isBaseAssistant(deployment) && !DeploymentController.hasAccess(context, deployment)) {
log.error("Forbidden deployment {}. Key: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
return context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
}

Expand All @@ -81,6 +87,7 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(rateLimitResult.status().getCode()));
rateLimitError.getError().setMessage(rateLimitResult.errorMessage());
log.error("Rate limit error {}. Key: {}. User sub: {}", rateLimitResult.errorMessage(), context.getProject(), context.getUserSub());
return context.respond(rateLimitResult.status(), rateLimitError);
}

Expand All @@ -92,6 +99,7 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
context.setUpstreamRoute(endpointRoute);

if (!endpointRoute.hasNext()) {
log.error("No route. Key: {}. Deployment: {}. User sub: {}", context.getProject(), deploymentId, context.getUserSub());
return context.respond(HttpStatus.BAD_GATEWAY, "No route");
}

Expand All @@ -106,6 +114,7 @@ private Future<?> sendRequest() {
HttpServerRequest request = context.getRequest();

if (!route.hasNext()) {
log.error("No route. Key: {}. Deployment: {}. User sub: {}", context.getProject(), context.getDeployment().getName(), context.getUserSub());
return context.respond(HttpStatus.BAD_GATEWAY, "No route");
}

Expand All @@ -122,14 +131,16 @@ private Future<?> sendRequest() {
.onFailure(this::handleProxyConnectionError);
}

private void handleRequestBody(Buffer requestBody) {
@VisibleForTesting
void handleRequestBody(Buffer requestBody) {
Deployment deployment = context.getDeployment();
log.info("Received body from client. Key: {}. Deployment: {}. Length: {}", context.getProject(),
context.getDeployment().getName(), requestBody.length());
deployment.getName(), requestBody.length());

context.setRequestBody(requestBody);
context.setRequestBodyTimestamp(System.currentTimeMillis());

if (context.getDeployment() instanceof Assistant) {
if (deployment instanceof Assistant) {
try {
Map.Entry<Buffer, Map<String, String>> enhancedRequest = enhanceAssistantRequest(context);
context.setRequestBody(enhancedRequest.getKey());
Expand All @@ -145,6 +156,15 @@ private void handleRequestBody(Buffer requestBody) {
}
}

if (deployment instanceof Model) {
try {
context.setRequestBody(enhanceModelRequest(context));
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't enhance model request: {}", e.getMessage());
}
}

sendRequest();
}

Expand Down Expand Up @@ -397,6 +417,25 @@ private static Map.Entry<Buffer, Map<String, String>> enhanceAssistantRequest(Pr
}
}

/**
 * Rewrites the {@code model} field of the request body when the target model declares an override name.
 *
 * <p>If the model has no override name, the request body held by the context is returned untouched.
 * Otherwise the body is parsed, any existing {@code model} field is removed, the field is set to the
 * override name and the re-serialized document is returned in a new buffer.
 *
 * @param context proxy context whose deployment is a {@link Model} and whose request body has been set
 * @return the original request body when no override name is configured, otherwise a new buffer
 *         containing the body with the replaced model name
 * @throws IllegalArgumentException if the request body does not parse to a JSON object
 * @throws Exception if the request body can't be parsed or the updated tree can't be serialized
 */
private static Buffer enhanceModelRequest(ProxyContext context) throws Exception {
    Model model = (Model) context.getDeployment();
    String overrideName = model.getOverrideName();
    Buffer requestBody = context.getRequestBody();
    if (overrideName == null) {
        return requestBody;
    }

    try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) {
        // The body is client-supplied: reject anything that is not a JSON object (array, scalar,
        // empty content) explicitly instead of failing on a blind cast.
        var parsed = ProxyUtil.MAPPER.readTree(stream);
        if (!(parsed instanceof ObjectNode)) {
            throw new IllegalArgumentException("Request body must be a JSON object");
        }
        ObjectNode tree = (ObjectNode) parsed;

        // Drop whatever model name the client sent, then write the configured one.
        tree.remove("model");
        tree.put("model", overrideName);

        return Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
    }
}

private static ObjectNode insertPrompt(ArrayNode messages, String prompt) {
return messages.insertObject(0)
.put("role", "system")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Application;
import com.epam.aidial.core.config.Config;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.upstream.UpstreamBalancer;
import com.epam.aidial.core.upstream.UpstreamProvider;
import com.epam.aidial.core.upstream.UpstreamRoute;
import com.epam.aidial.core.util.ProxyUtil;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerRequest;
Expand All @@ -20,21 +24,28 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.io.IOException;
import java.util.HashMap;
import java.util.Set;

import static com.epam.aidial.core.Proxy.HEADER_API_KEY;
import static com.epam.aidial.core.Proxy.HEADER_CONTENT_TYPE_APPLICATION_JSON;
import static com.epam.aidial.core.util.HttpStatus.BAD_GATEWAY;
import static com.epam.aidial.core.util.HttpStatus.FORBIDDEN;
import static com.epam.aidial.core.util.HttpStatus.NOT_FOUND;
import static com.epam.aidial.core.util.HttpStatus.UNSUPPORTED_MEDIA_TYPE;
import static io.vertx.core.http.HttpHeaders.AUTHORIZATION;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -72,14 +83,31 @@ public void testForbiddenDeployment() {
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(HEADER_CONTENT_TYPE_APPLICATION_JSON);
Config config = new Config();
config.setApplications(new HashMap<>());
config.getApplications().put("app1", new Application());
Application app = new Application();
app.setName("app1");
app.setUserRoles(Set.of("role1"));
config.getApplications().put("app1", app);
when(context.getConfig()).thenReturn(config);

controller.handle("app1", "api");
controller.handle("app1", "chat/completions");

verify(context).respond(eq(FORBIDDEN), anyString());
}

/**
 * Handling a deployment id that is absent from the configured applications
 * must answer with NOT_FOUND.
 */
@Test
public void testDeploymentNotFound() {
    // Config that knows only "app1".
    HashMap<String, Application> applications = new HashMap<>();
    applications.put("app1", new Application());
    Config config = new Config();
    config.setApplications(applications);

    when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(HEADER_CONTENT_TYPE_APPLICATION_JSON);
    when(context.getConfig()).thenReturn(config);

    controller.handle("unknown-app", "chat/completions");

    verify(context).respond(eq(NOT_FOUND), anyString());
}

@Test
public void testNoRoute() {
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(HEADER_CONTENT_TYPE_APPLICATION_JSON);
Expand Down Expand Up @@ -189,5 +217,67 @@ public void testHandleProxyRequest_PropagateApiHeader() {
assertNull(proxyHeaders.get(HEADER_API_KEY));
}

/**
 * When the model declares an override name, the request body stored in the context
 * after {@code handleRequestBody} must carry that name in its {@code model} field.
 */
@Test
public void testHandleRequestBody_OverrideModelName() throws IOException {
    UpstreamRoute upstreamRoute = mock(UpstreamRoute.class, RETURNS_DEEP_STUBS);
    when(upstreamRoute.hasNext()).thenReturn(true);
    when(context.getUpstreamRoute()).thenReturn(upstreamRoute);
    HttpServerRequest request = mock(HttpServerRequest.class, RETURNS_DEEP_STUBS);
    when(context.getRequest()).thenReturn(request);
    when(proxy.getClient()).thenReturn(mock(HttpClient.class, RETURNS_DEEP_STUBS));

    Model model = new Model();
    model.setName("name");
    model.setEndpoint("http://host/model");
    model.setOverrideName("overrideName");
    when(context.getDeployment()).thenReturn(model);
    String body = """
            {
            "model": "name"
            }
            """;
    Buffer requestBody = Buffer.buffer(body);
    // Use the real accessors so the body written by the controller can be read back.
    when(context.getRequestBody()).thenCallRealMethod();
    doCallRealMethod().when(context).setRequestBody(any());

    controller.handleRequestBody(requestBody);

    Buffer updatedBody = context.getRequestBody();
    assertNotNull(updatedBody);

    byte[] content = updatedBody.getBytes();
    ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content);
    // Expected value first, actual second, so a failure reports the values the right way round.
    assertEquals("overrideName", tree.get("model").asText());
}

/**
 * Without an override name on the model, the request body stored in the context
 * after {@code handleRequestBody} must equal the buffer that was passed in.
 */
@Test
public void testHandleRequestBody_NotOverrideModelName() {
    // Model with no override name configured.
    Model plainModel = new Model();
    plainModel.setName("name");
    plainModel.setEndpoint("http://host/model");
    when(context.getDeployment()).thenReturn(plainModel);

    // A route that always has a next upstream, plus deep-stubbed request and client.
    UpstreamRoute route = mock(UpstreamRoute.class, RETURNS_DEEP_STUBS);
    when(route.hasNext()).thenReturn(true);
    when(context.getUpstreamRoute()).thenReturn(route);
    HttpServerRequest serverRequest = mock(HttpServerRequest.class, RETURNS_DEEP_STUBS);
    when(context.getRequest()).thenReturn(serverRequest);
    when(proxy.getClient()).thenReturn(mock(HttpClient.class, RETURNS_DEEP_STUBS));

    // Use the real accessors so the stored body can be read back.
    when(context.getRequestBody()).thenCallRealMethod();
    doCallRealMethod().when(context).setRequestBody(any());

    Buffer originalBody = Buffer.buffer("""
            {
            "model": "name"
            }
            """);

    controller.handleRequestBody(originalBody);

    assertEquals(originalBody, context.getRequestBody());
}


}

0 comments on commit 6aafc45

Please sign in to comment.