Skip to content

Commit

Permalink
Merge branch 'development' into feat/issue-686
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay authored Feb 17, 2025
2 parents 9f942f3 + 7217984 commit e80855d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.epam.aidial.core.config.Upstream;
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.data.ErrorData;
import com.epam.aidial.core.server.limiter.RateLimitResult;
import com.epam.aidial.core.server.upstream.UpstreamRoute;
Expand Down Expand Up @@ -45,14 +46,14 @@ public Future<?> handle() {
Route route = selectRoute();
if (route == null) {
log.warn("RouteController can't find a route to proceed the request: {}", getRequestUri());
context.respond(HttpStatus.BAD_GATEWAY, "No route");
respond(HttpStatus.BAD_GATEWAY, "No route");
return Future.succeededFuture();
}

if (!route.hasAccess(context.getUserRoles())) {
log.error("Forbidden route {}. Trace: {}. Span: {}. Project: {}. User sub: {}.",
route.getName(), context.getTraceId(), context.getSpanId(), context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, "Forbidden route");
respond(HttpStatus.FORBIDDEN, "Forbidden route");
return Future.succeededFuture();
}

Expand Down Expand Up @@ -105,6 +106,7 @@ private void handleRequestBody(Buffer requestBody) {
proxy.getRateLimiter().limit(context, context.getRoute())
.compose(rateLimitResult -> {
if (rateLimitResult.status() == HttpStatus.OK) {
setupProxyApiKeyData();
return sendRequest();
} else {
handleRateLimitHit(rateLimitResult);
Expand All @@ -118,6 +120,17 @@ private void handleRequestBody(Buffer requestBody) {
}
}

private void setupProxyApiKeyData() {
Upstream upstream = context.getUpstreamRoute().get();
if (upstream != null && upstream.getKey() != null) {
return;
}
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
}

/**
* Called when proxy connected to the origin.
*/
Expand All @@ -129,7 +142,12 @@ private void handleProxyRequest(HttpClientRequest proxyRequest) {

Upstream upstream = context.getUpstreamRoute().get();
ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers());
proxyRequest.putHeader(Proxy.HEADER_API_KEY, upstream.getKey());
if (upstream != null && upstream.getKey() != null) {
proxyRequest.putHeader(Proxy.HEADER_API_KEY, upstream.getKey());
} else {
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey());
}

Buffer proxyRequestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length()));
Expand Down Expand Up @@ -185,7 +203,7 @@ private boolean canRetry(UpstreamRoute route) {
try {
route.next();
} catch (HttpException e) {
context.respond(e);
respond(e);
return false;
}
return true;
Expand Down Expand Up @@ -218,21 +236,21 @@ private void handleRateLimitHit(RateLimitResult result) {
httpException = new HttpException(result.status(), errorMessage);
}

context.respond(httpException);
respond(httpException);
}

private void handleError(Throwable error) {
String route = context.getRoute().getName();
log.error("Failed to handle route {}", route, error);
context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process route request: " + route);
respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process route request: " + route);
}

/**
* Called when proxy failed to receive request body from the client.
*/
private void handleRequestBodyError(Throwable error) {
log.warn("Failed to receive client body: {}", error.getMessage());
context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
}

/**
Expand Down Expand Up @@ -270,6 +288,7 @@ private void handleResponseError(Throwable error) {
log.warn("Can't send response to client: {}", error.getMessage());
context.getProxyRequest().reset(); // drop connection to stop origin response
context.getResponse().reset(); // drop connection, so that partial client response won't seem complete
finalizeRequest();
}

private Route selectRoute() {
Expand Down Expand Up @@ -311,4 +330,26 @@ private String getEndpointUri(Upstream upstream) {
}
return uriBuilder.toString();
}

private void respond(HttpStatus status, String result) {
finalizeRequest();
context.respond(status, result);
}

private void respond(HttpException exception) {
finalizeRequest();
context.respond(exception);
}

private void finalizeRequest() {
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,15 @@ public static void initFromContext(ApiKeyData proxyApiKeyData, ProxyContext cont
proxyApiKeyData.setTraceId(apiKeyData.getTraceId());
currentPath = new ArrayList<>(context.getApiKeyData().getExecutionPath());
}
currentPath.add(context.getDeployment().getName());
if (context.getDeployment() != null) {
currentPath.add(context.getDeployment().getName());
proxyApiKeyData.setSourceDeployment(context.getDeployment().getName());
} else if (context.getRoute() != null) {
currentPath.add(context.getRoute().getName());
proxyApiKeyData.setSourceDeployment(context.getRoute().getName());
}
proxyApiKeyData.setExecutionPath(currentPath);
proxyApiKeyData.setSpanId(context.getSpanId());
proxyApiKeyData.setSourceDeployment(context.getDeployment().getName());
}

@JsonIgnore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import java.util.UUID;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

class RouteApiTest extends ResourceBaseTest {

@ParameterizedTest
@MethodSource("datasource")
void route(HttpMethod method, String path, String apiKey, int expectedStatus, String expectedResponse) {
TestWebServer.Handler handler = request -> new MockResponse().setBody(request.getPath());
TestWebServer.Handler handler = request -> {
assertNotNull(request.getHeader(Proxy.HEADER_API_KEY));
return new MockResponse().setBody(request.getPath());
};
try (TestWebServer server = new TestWebServer(9876, handler)) {
String reqBody = (method == HttpMethod.POST) ? UUID.randomUUID().toString() : null;
Response resp = send(method, path, null, reqBody, "api-key", apiKey);
Expand Down

0 comments on commit e80855d

Please sign in to comment.