From a65de8bb006449c0ce09a486d25afd1f8a66768b Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Mon, 13 Jan 2025 17:26:49 -0800 Subject: [PATCH] Ensure that cancelled request returns I tried cancelling a request and found that the client would hang. This change reports an exception for each remaining request instead. Signed-off-by: Michael Froh --- .../search/TransportMultiSearchAction.java | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java index c9054bd59b975..dcb2ce6eb88da 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java @@ -44,6 +44,7 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.core.tasks.TaskId; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; @@ -195,7 +196,20 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It if (responseCounter.decrementAndGet() == 0) { assert requests.isEmpty(); finish(); - } else if (isCancelled(request.request.getParentTask()) == false) { + } else if (isCancelled(request.request.getParentTask())) { + // Drain the rest of the queue + SearchRequestSlot request; + while ((request = requests.poll()) != null) { + responses.set( + request.responseSlot, + new MultiSearchResponse.Item(null, new TaskCancelledException("Parent task was cancelled")) + ); + if (responseCounter.decrementAndGet() == 0) { + assert requests.isEmpty(); + finish(); + } + } + } else { if (thread == Thread.currentThread()) { // we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread threadPool.generic()