Skip to content

Commit

Permalink
fix: propagate headers to feature endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 10, 2025
1 parent b01f3e9 commit 6796c76
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package com.epam.aidial.core.server.controller;

import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Upstream;
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
Expand Down Expand Up @@ -87,18 +91,39 @@ private void handleRequestError(String deploymentId, Throwable error) {
/**
* Called when proxy connected to the origin.
*/
private void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin: {}", proxyRequest.connection().remoteAddress());
void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
proxyRequest.connection().remoteAddress());

HttpServerRequest request = context.getRequest();
context.setProxyRequest(proxyRequest);
context.setProxyConnectTimestamp(System.currentTimeMillis());

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers());
Deployment deployment = context.getDeployment();
MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap();
if (!deployment.isForwardAuthToken()) {
excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever");
}

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders);

ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey());

if (context.getDeployment() instanceof Model model && !model.getUpstreams().isEmpty()) {
Upstream upstream = context.getUpstreamRoute().get();
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_ENDPOINT, upstream.getEndpoint());
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_KEY, upstream.getKey());
proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_EXTRA_DATA, upstream.getExtraData());
}

Buffer proxyRequestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length()));
Buffer requestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length()));
context.getRequestHeaders().forEach(proxyRequest::putHeader);

proxyRequest.send(proxyRequestBody)
proxyRequest.send(requestBody)
.onSuccess(this::handleProxyResponse)
.onFailure(this::handleProxyRequestError);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import io.vertx.core.http.HttpMethod;
import lombok.SneakyThrows;
import okhttp3.Headers;

import org.junit.jupiter.api.Test;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class FeaturesApiTest extends ResourceBaseTest {

Expand Down Expand Up @@ -66,9 +71,22 @@ void testUpstreamEndpoint(String inboundPath, String upstream) {
void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) {
URI uri = URI.create(upstream);
try (TestWebServer server = new TestWebServer(uri.getPort())) {
server.map(method, uri.getPath(), 200, "PONG");
Response response = send(method, inboundPath);
verify(response, 200, "PONG");
server.map(method, uri.getPath(), request -> TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(request.getHeaders())));

Response response = send(method, inboundPath, null, "", "foo", "bar");
verify(response, 200, "PONG", "foo", "bar");
}
}

private static String[] convertHeadersToFlatArray(Headers headers) {
List<String> flatHeadersList = new ArrayList<>();
for (Map.Entry<String, List<String>> entry : headers.toMultimap().entrySet()) {
String key = entry.getKey();
for (String value : entry.getValue()) {
flatHeadersList.add(key);
flatHeadersList.add(value);
}
}
return flatHeadersList.toArray(new String[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,14 @@ static void verify(Response response, int status) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
}

static void verify(Response response, int status, String body) {
static void verify(Response response, int status, String body, String... headers) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
assertEquals(body, response.body());
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
assertEquals(value, response.headers.get(key));
}
}

static void verifyJson(Response response, int status, String body) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public void map(HttpMethod method, String path, int status) {
map(method, path, status, "");
}

public void map(HttpMethod method, String path, int status, String body) {
map(method, path, createResponse(status, body));
public void map(HttpMethod method, String path, int status, String body, String... headers) {
map(method, path, createResponse(status, body, headers));
}

private MockResponse onRequest(RecordedRequest request) {
Expand All @@ -67,10 +67,15 @@ private MockResponse onRequest(RecordedRequest request) {
return response;
}

private static MockResponse createResponse(int status, String body) {
public static MockResponse createResponse(int status, String body, String... headers) {
MockResponse response = new MockResponse();
response.setResponseCode(status);
response.setBody(body);
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
response.setHeader(key, value);
}
return response;
}

Expand Down

0 comments on commit 6796c76

Please sign in to comment.