diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java b/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java index 22ef61da..541de08d 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java @@ -5,6 +5,7 @@ import com.epam.aidial.core.config.Upstream; import com.epam.aidial.core.server.Proxy; import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.ApiKeyData; import com.epam.aidial.core.server.data.ErrorData; import com.epam.aidial.core.server.limiter.RateLimitResult; import com.epam.aidial.core.server.upstream.UpstreamRoute; @@ -45,14 +46,14 @@ public Future handle() { Route route = selectRoute(); if (route == null) { log.warn("RouteController can't find a route to proceed the request: {}", getRequestUri()); - context.respond(HttpStatus.BAD_GATEWAY, "No route"); + respond(HttpStatus.BAD_GATEWAY, "No route"); return Future.succeededFuture(); } if (!route.hasAccess(context.getUserRoles())) { log.error("Forbidden route {}. Trace: {}. Span: {}. Project: {}. User sub: {}.", route.getName(), context.getTraceId(), context.getSpanId(), context.getProject(), context.getUserSub()); - context.respond(HttpStatus.FORBIDDEN, "Forbidden route"); + respond(HttpStatus.FORBIDDEN, "Forbidden route"); return Future.succeededFuture(); } @@ -105,6 +106,7 @@ private void handleRequestBody(Buffer requestBody) { proxy.getRateLimiter().limit(context, context.getRoute()) .compose(rateLimitResult -> { if (rateLimitResult.status() == HttpStatus.OK) { + setupProxyApiKeyData(); return sendRequest(); } else { handleRateLimitHit(rateLimitResult); @@ -118,6 +120,17 @@ private void handleRequestBody(Buffer requestBody) { } } + private void setupProxyApiKeyData() { + Upstream upstream = context.getUpstreamRoute().get(); + if (upstream != null && upstream.getKey() != null) { + return; + } + ApiKeyData proxyApiKeyData = new ApiKeyData(); + context.setProxyApiKeyData(proxyApiKeyData); + ApiKeyData.initFromContext(proxyApiKeyData, context); + proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData); + } + /** * Called when proxy connected to the origin. */ @@ -129,7 +142,12 @@ private void handleProxyRequest(HttpClientRequest proxyRequest) { Upstream upstream = context.getUpstreamRoute().get(); ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers()); - proxyRequest.putHeader(Proxy.HEADER_API_KEY, upstream.getKey()); + if (upstream != null && upstream.getKey() != null) { + proxyRequest.putHeader(Proxy.HEADER_API_KEY, upstream.getKey()); + } else { + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey()); + } Buffer proxyRequestBody = context.getRequestBody(); proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length())); @@ -185,7 +203,7 @@ private boolean canRetry(UpstreamRoute route) { try { route.next(); } catch (HttpException e) { - context.respond(e); + respond(e); return false; } return true; @@ -218,13 +236,13 @@ private void handleRateLimitHit(RateLimitResult result) { httpException = new HttpException(result.status(), errorMessage); } - context.respond(httpException); + respond(httpException); } private void handleError(Throwable error) { String route = context.getRoute().getName(); log.error("Failed to handle route {}", route, error); - context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process route request: " + route); + respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process route request: " + route); } /** @@ -232,7 +250,7 @@ private void handleError(Throwable error) { */ private void handleRequestBodyError(Throwable error) { log.warn("Failed to receive client body: {}", error.getMessage()); - context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); + respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); } /** @@ -270,6 +288,7 @@ private void handleResponseError(Throwable error) { log.warn("Can't send response to client: {}", error.getMessage()); context.getProxyRequest().reset(); // drop connection to stop origin response context.getResponse().reset(); // drop connection, so that partial client response won't seem complete + finalizeRequest(); } private Route selectRoute() { @@ -311,4 +330,26 @@ private String getEndpointUri(Upstream upstream) { } return uriBuilder.toString(); } + + private void respond(HttpStatus status, String result) { + finalizeRequest(); + context.respond(status, result); + } + + private void respond(HttpException exception) { + finalizeRequest(); + context.respond(exception); + } + + private void finalizeRequest() { + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + if (proxyApiKeyData != null) { + proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData) + .onSuccess(invalidated -> { + if (!invalidated) { + log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey()); + } + }).onFailure(error -> log.error("error occurred on invalidating per-request key", error)); + } + } } diff --git a/server/src/main/java/com/epam/aidial/core/server/data/ApiKeyData.java b/server/src/main/java/com/epam/aidial/core/server/data/ApiKeyData.java index cb1bf08a..a8ed2d74 100644 --- a/server/src/main/java/com/epam/aidial/core/server/data/ApiKeyData.java +++ b/server/src/main/java/com/epam/aidial/core/server/data/ApiKeyData.java @@ -78,10 +78,15 @@ public static void initFromContext(ApiKeyData proxyApiKeyData, ProxyContext cont proxyApiKeyData.setTraceId(apiKeyData.getTraceId()); currentPath = new ArrayList<>(context.getApiKeyData().getExecutionPath()); } - currentPath.add(context.getDeployment().getName()); + if (context.getDeployment() != null) { + currentPath.add(context.getDeployment().getName()); + proxyApiKeyData.setSourceDeployment(context.getDeployment().getName()); + } else if (context.getRoute() != null) { + currentPath.add(context.getRoute().getName()); + proxyApiKeyData.setSourceDeployment(context.getRoute().getName()); + } proxyApiKeyData.setExecutionPath(currentPath); proxyApiKeyData.setSpanId(context.getSpanId()); - proxyApiKeyData.setSourceDeployment(context.getDeployment().getName()); } @JsonIgnore diff --git a/server/src/test/java/com/epam/aidial/core/server/RouteApiTest.java b/server/src/test/java/com/epam/aidial/core/server/RouteApiTest.java index 21f35dfe..a95fc29a 100644 --- a/server/src/test/java/com/epam/aidial/core/server/RouteApiTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/RouteApiTest.java @@ -12,6 +12,7 @@ import java.util.UUID; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; class RouteApiTest extends ResourceBaseTest { @@ -19,7 +20,10 @@ class RouteApiTest extends ResourceBaseTest { @ParameterizedTest @MethodSource("datasource") void route(HttpMethod method, String path, String apiKey, int expectedStatus, String expectedResponse) { - TestWebServer.Handler handler = request -> new MockResponse().setBody(request.getPath()); + TestWebServer.Handler handler = request -> { + assertNotNull(request.getHeader(Proxy.HEADER_API_KEY)); + return new MockResponse().setBody(request.getPath()); + }; try (TestWebServer server = new TestWebServer(9876, handler)) { String reqBody = (method == HttpMethod.POST) ? UUID.randomUUID().toString() : null; Response resp = send(method, path, null, reqBody, "api-key", apiKey);