From feb50d09ac2fc7fcbffa868eca9909e6e31c9b69 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Tue, 6 Jun 2023 14:36:01 -0700 Subject: [PATCH 01/10] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline Signed-off-by: Navneet Verma --- CHANGELOG.md | 2 + .../search_pipeline/50_script_processor.yml | 2 +- .../search/AbstractSearchAsyncAction.java | 17 +- .../search/ArraySearchPhaseResults.java | 2 +- .../search/CanMatchPreFilterSearchPhase.java | 7 +- .../action/search/DfsQueryPhase.java | 2 +- .../action/search/ExpandSearchPhase.java | 2 +- .../action/search/FetchSearchPhase.java | 2 +- .../SearchDfsQueryThenFetchAsyncAction.java | 7 +- .../opensearch/action/search/SearchPhase.java | 33 ++- .../action/search/SearchPhaseContext.java | 2 +- .../action/search/SearchPhaseResults.java | 10 +- .../SearchQueryThenFetchAsyncAction.java | 7 +- .../search/SearchScrollAsyncAction.java | 2 +- ...SearchScrollQueryThenFetchAsyncAction.java | 2 +- .../action/search/TransportSearchAction.java | 12 +- .../plugins/SearchPipelinePlugin.java | 12 ++ .../opensearch/search/pipeline/Pipeline.java | 59 +++++- .../search/pipeline/PipelinedRequest.java | 12 ++ .../SearchPhaseInjectorProcessor.java | 37 ++++ .../pipeline/SearchPipelineService.java | 8 +- .../AbstractSearchAsyncActionTests.java | 5 +- .../CanMatchPreFilterSearchPhaseTests.java | 22 +- .../action/search/SearchAsyncActionTests.java | 19 +- .../SearchQueryThenFetchAsyncActionTests.java | 7 +- .../pipeline/SearchPipelineServiceTests.java | 199 +++++++++++++++++- 26 files changed, 440 insertions(+), 51 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ceb36bd41ba3..3db3c51fe6767 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -90,6 +90,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Security ## [Unreleased 2.x] +### Added +- [SearchPipeline] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline.([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283)) ### Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452)) ### Dependencies diff --git a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml index 9b2dc0c41ff31..9d855e8a1861a 100644 --- a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml +++ b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml @@ -39,7 +39,7 @@ teardown: { "script" : { "lang" : "painless", - "source" : "ctx._source['size'] += 10; ctx._source['from'] -= 1; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];" + "source" : "ctx._source['size'] += 10; ctx._source['from'] = ctx._source['from'] <= 0 ? ctx._source['from'] : ctx._source['from'] - 1 ; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];" } } ] diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index c2dd5f639db75..269a728394f2d 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -57,6 +57,8 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.PipelinedRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.ArrayDeque; @@ -116,6 +118,7 @@ abstract class AbstractSearchAsyncAction exten private final boolean throttleConcurrentRequests; private final List releasables = new ArrayList<>(); + private final SearchPipelineService searchPipelineService; AbstractSearchAsyncAction( String name, @@ -134,7 +137,8 @@ abstract class AbstractSearchAsyncAction exten SearchTask task, SearchPhaseResults resultConsumer, int maxConcurrentRequestsPerNode, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { super(name); final List toSkipIterators = new ArrayList<>(); @@ -170,6 +174,7 @@ abstract class AbstractSearchAsyncAction exten this.indexRoutings = indexRoutings; this.results = resultConsumer; this.clusters = clusters; + this.searchPipelineService = searchPipelineService; } @Override @@ -696,7 +701,15 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() - executeNextPhase(this, getNextPhase(results, this)); + final SearchPhase nextPhase = getNextPhase(results, this); + // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that + // tests pass, we need to do null check on next phase. + if (nextPhase != null) { + + final PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(this.getRequest()); + pipelinedRequest.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + } + executeNextPhase(this, nextPhase); } @Override diff --git a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java index 61c81e6cda97a..653b0e8aedb9d 100644 --- a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) { } @Override - AtomicArray getAtomicArray() { + public AtomicArray getAtomicArray() { return results; } } diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index 9694695e4fbbb..bece269902274 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -41,6 +41,7 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortOrder; @@ -90,7 +91,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction, SearchPhase> phaseFactory, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { // We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests super( @@ -110,7 +112,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction, SearchPhase> nextPhaseFactory, SearchPhaseContext context ) { - super("dfs_query"); + super(SearchPhaseName.DFS_QUERY.getName()); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; this.searchResults = searchResults; diff --git a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java index cdefe7c2c1712..618a5620ce093 100644 --- a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase { private final AtomicArray queryResults; ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray queryResults) { - super("expand"); + super(SearchPhaseName.EXPAND.getName()); this.context = context; this.searchResponse = searchResponse; this.queryResults = queryResults; diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java index 31ec896856ce6..85a3d140977bb 100644 --- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase { SearchPhaseContext context, BiFunction, SearchPhase> nextPhaseFactory ) { - super("fetch"); + super(SearchPhaseName.FETCH.getName()); if (context.getNumShards() != resultConsumer.getNumShards()) { throw new IllegalStateException( "number of shards must match the length of the query results but doesn't:" diff --git a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 71a986c0e15f7..422c10e222c2a 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -41,6 +41,7 @@ import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.dfs.DfsSearchResult; import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.List; @@ -76,7 +77,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction final TransportSearchAction.SearchTimeProvider timeProvider, final ClusterState clusterState, final SearchTask task, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { super( "dfs", @@ -95,7 +97,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction task, new ArraySearchPhaseResults<>(shardsIts.size()), request.getMaxConcurrentShardRequests(), - clusters + clusters, + searchPipelineService ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 50f0940754078..1b009a983d2b0 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -34,6 +34,7 @@ import org.opensearch.common.CheckedRunnable; import java.io.IOException; +import java.util.Locale; import java.util.Objects; /** @@ -41,7 +42,7 @@ * * @opensearch.internal */ -abstract class SearchPhase implements CheckedRunnable { +public abstract class SearchPhase implements CheckedRunnable { private final String name; protected SearchPhase(String name) { @@ -54,4 +55,34 @@ protected SearchPhase(String name) { public String getName() { return name; } + + /** + * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined + * in {@link SearchPhaseName} + * @return {@link SearchPhaseName} + */ + public SearchPhaseName getSearchPhaseName() { + return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); + } + + /** + * Enum for different Search Phases in OpenSearch + * @opensearch.internal + */ + public enum SearchPhaseName { + QUERY("query"), + FETCH("fetch"), + DFS_QUERY("dfs_query"), + EXPAND("expand"); + + private final String name; + + SearchPhaseName(final String name) { + this.name = name; + } + + public String getName() { + return name; + } + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java index 04d0dab088d35..4b609037ba907 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java @@ -50,7 +50,7 @@ * * @opensearch.internal */ -interface SearchPhaseContext extends Executor { +public interface SearchPhaseContext extends Executor { // TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases /** diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java index 1baea0e721c44..2e6068b1ecddc 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java @@ -42,7 +42,7 @@ * * @opensearch.internal */ -abstract class SearchPhaseResults { +public abstract class SearchPhaseResults { private final int numShards; SearchPhaseResults(int numShards) { @@ -75,7 +75,13 @@ final int getNumShards() { void consumeShardFailure(int shardIndex) {} - AtomicArray getAtomicArray() { + /** + * Returns an {@link AtomicArray} of {@link Result}, which are nothing but the SearchPhaseResults + * for shards. The {@link Result} are of type {@link SearchPhaseResult} + * + * @return an {@link AtomicArray} of {@link Result} + */ + public AtomicArray getAtomicArray() { throw new UnsupportedOperationException(); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java index 1ead14aac6b51..2aaa1d788c5bc 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -42,6 +42,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.transport.Transport; @@ -81,7 +82,8 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction fetchResults ) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.getName()) { @Override public void run() throws IOException { sendResponse(queryPhase, fetchResults); diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 4119cb1cf28a0..51ffeb2ac83bc 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -92,7 +92,7 @@ protected void executeInitialPhase( @Override protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.getName()) { @Override public void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase( diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 69f529fe1d00c..fe7fc2d7ee383 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -346,7 +346,8 @@ public AbstractSearchAsyncAction asyncSearchAction( task, new ArraySearchPhaseResults<>(shardsIts.size()), searchRequest.getMaxConcurrentShardRequests(), - clusters + clusters, + searchPipelineService ) { @Override protected void executePhaseOnShard( @@ -1163,7 +1164,8 @@ public void run() { } }; }, - clusters + clusters, + searchPipelineService ); } else { final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( @@ -1193,7 +1195,8 @@ public void run() { timeProvider, clusterState, task, - clusters + clusters, + searchPipelineService ); break; case QUERY_THEN_FETCH: @@ -1213,7 +1216,8 @@ public void run() { timeProvider, clusterState, task, - clusters + clusters, + searchPipelineService ); break; default: diff --git a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java index b8ceddecd3d20..f9070d14aa296 100644 --- a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java +++ b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java @@ -9,6 +9,7 @@ package org.opensearch.plugins; import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseInjectorProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -42,4 +43,15 @@ default Map> getRequestProcess default Map> getResponseProcessors(Processor.Parameters parameters) { return Collections.emptyMap(); } + + /** + * Returns additional search pipeline search phase injector processor types added by this plugin. + * + * The key of the returned {@link Map} is the unique name for the processor which is specified + * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory} + * to create the processor from a given pipeline configuration. + */ + default Map> getPhaseInjectorProcessors(Processor.Parameters parameters) { + return Collections.emptyMap(); + } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index c9a5f865d507e..8ae74c9324ab2 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -9,6 +9,8 @@ package org.opensearch.search.pipeline; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.Nullable; @@ -17,6 +19,7 @@ import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchPhaseResult; import java.util.ArrayList; import java.util.Arrays; @@ -35,6 +38,7 @@ class Pipeline { public static final String REQUEST_PROCESSORS_KEY = "request_processors"; public static final String RESPONSE_PROCESSORS_KEY = "response_processors"; + public static final String PHASE_PROCESSORS_KEY = "phase_injector_processors"; private final String id; private final String description; private final Integer version; @@ -43,8 +47,8 @@ class Pipeline { // Then these can be CompoundProcessors instead of lists. private final List searchRequestProcessors; private final List searchResponseProcessors; - private final NamedWriteableRegistry namedWriteableRegistry; + private final List searchPhaseInjectorProcessors; private Pipeline( String id, @@ -52,6 +56,7 @@ private Pipeline( @Nullable Integer version, List requestProcessors, List responseProcessors, + List phaseInjectorProcessors, NamedWriteableRegistry namedWriteableRegistry ) { this.id = id; @@ -59,6 +64,7 @@ private Pipeline( this.version = version; this.searchRequestProcessors = requestProcessors; this.searchResponseProcessors = responseProcessors; + this.searchPhaseInjectorProcessors = phaseInjectorProcessors; this.namedWriteableRegistry = namedWriteableRegistry; } @@ -67,6 +73,7 @@ static Pipeline create( Map config, Map> requestProcessorFactories, Map> responseProcessorFactories, + Map> phaseInjectorProcessorFactories, NamedWriteableRegistry namedWriteableRegistry ) throws Exception { String description = ConfigurationUtils.readOptionalStringProperty(null, null, config, DESCRIPTION_KEY); @@ -79,7 +86,16 @@ static Pipeline create( config, RESPONSE_PROCESSORS_KEY ); - List responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs); + + final List> phaseProcessorConfigs = ConfigurationUtils.readOptionalList( + null, + null, + config, + PHASE_PROCESSORS_KEY + ); + final List responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs); + final List phaseProcessors = readProcessors(phaseInjectorProcessorFactories, phaseProcessorConfigs); + if (config.isEmpty() == false) { throw new OpenSearchParseException( "pipeline [" @@ -88,7 +104,7 @@ static Pipeline create( + Arrays.toString(config.keySet().toArray()) ); } - return new Pipeline(id, description, version, requestProcessors, responseProcessors, namedWriteableRegistry); + return new Pipeline(id, description, version, requestProcessors, responseProcessors, phaseProcessors, namedWriteableRegistry); } private static List readProcessors( @@ -111,7 +127,17 @@ private static List readProcessors( processors.add(processorFactories.get(type).create(processorFactories, tag, description, config)); } } - return Collections.unmodifiableList(processors); + return processors; + } + + List flattenAllProcessors() { + List allProcessors = new ArrayList<>( + searchRequestProcessors.size() + searchResponseProcessors.size() + searchPhaseInjectorProcessors.size() + ); + allProcessors.addAll(searchRequestProcessors); + allProcessors.addAll(searchPhaseInjectorProcessors); + allProcessors.addAll(searchResponseProcessors); + return allProcessors; } String getId() { @@ -134,6 +160,10 @@ List getSearchResponseProcessors() { return searchResponseProcessors; } + List getSearchPhaseInjectorProcessors() { + return searchPhaseInjectorProcessors; + } + SearchRequest transformRequest(SearchRequest request) throws Exception { if (searchRequestProcessors.isEmpty() == false) { try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { @@ -168,6 +198,27 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response) 0, Collections.emptyList(), Collections.emptyList(), + Collections.emptyList(), null ); + + SearchPhaseResults runSearchPhaseTransformer( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext context, + String currentPhase, + String nextPhase + ) throws SearchPipelineProcessingException { + + try { + for (SearchPhaseInjectorProcessor searchPhaseInjectorProcessor : searchPhaseInjectorProcessors) { + if (currentPhase.equals(searchPhaseInjectorProcessor.getBeforePhase().getName()) + && nextPhase.equals(searchPhaseInjectorProcessor.getAfterPhase().getName())) { + searchPhaseResult = searchPhaseInjectorProcessor.execute(searchPhaseResult, context); + } + } + return searchPhaseResult; + } catch (Exception e) { + throw new SearchPipelineProcessingException(e); + } + } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index 0cfff013f4021..e45b510d7c760 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -8,8 +8,11 @@ package org.opensearch.search.pipeline; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchPhaseResult; /** * Groups a search pipeline based on a request and the request after being transformed by the pipeline. @@ -33,6 +36,15 @@ public SearchRequest transformedRequest() { return transformedRequest; } + public SearchPhaseResults transformSearchPhase( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final String currentPhase, + final String nextPhase + ) { + return pipeline.runSearchPhaseTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); + } + // Visible for testing Pipeline getPipeline() { return pipeline; diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java new file mode 100644 index 0000000000000..0d4f7950596cf --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline; + +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; + +/** + * Creates a processor that runs between Phases of the Search. + */ +public interface SearchPhaseInjectorProcessor extends Processor { + SearchPhaseResults execute( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ); + + /** + * The phase which should have run before, this processor can start executing. + * @return {@link SearchPhase.SearchPhaseName} + */ + SearchPhase.SearchPhaseName getBeforePhase(); + + /** + * The phase which should run after, this processor execution. + * @return {@link SearchPhase.SearchPhaseName} + */ + SearchPhase.SearchPhaseName getAfterPhase(); + +} diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 87c09bd971284..824253fa9294c 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -72,6 +72,8 @@ public class SearchPipelineService implements ClusterStateApplier, ReportingServ private final ScriptService scriptService; private final Map> requestProcessorFactories; private final Map> responseProcessorFactories; + + private final Map> phaseInjectorProcessorFactories; private volatile Map pipelines = Collections.emptyMap(); private final ThreadPool threadPool; private final List> searchPipelineClusterStateListeners = new CopyOnWriteArrayList<>(); @@ -112,6 +114,7 @@ public SearchPipelineService( ); this.requestProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getRequestProcessors(parameters)); this.responseProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getResponseProcessors(parameters)); + this.phaseInjectorProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getPhaseInjectorProcessors(parameters)); putPipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.PUT_SEARCH_PIPELINE_KEY, true); deletePipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.DELETE_SEARCH_PIPELINE_KEY, true); this.isEnabled = isEnabled; @@ -177,6 +180,7 @@ void innerUpdatePipelines(SearchPipelineMetadata newSearchPipelineMetadata) { newConfiguration.getConfigAsMap(), requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry ); newPipelines.put(newConfiguration.getId(), new PipelineHolder(newConfiguration, newPipeline)); @@ -276,6 +280,7 @@ void validatePipeline(Map searchPipelineInfos pipelineConfig, requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry ); List exceptions = new ArrayList<>(); @@ -353,7 +358,7 @@ static ClusterState innerDelete(DeleteSearchPipelineRequest request, ClusterStat return newState.build(); } - public PipelinedRequest resolvePipeline(SearchRequest searchRequest) throws Exception { + public PipelinedRequest resolvePipeline(SearchRequest searchRequest) { Pipeline pipeline = Pipeline.NO_OP_PIPELINE; if (isEnabled == false) { @@ -372,6 +377,7 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest) throws Exce searchRequest.source().searchPipelineSource(), requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry ); } catch (Exception e) { diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index ad2657517df9a..f55ce93d019fe 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -52,6 +52,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -83,6 +84,7 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; + private SearchPipelineService searchPipelineService; @Before @Override @@ -161,7 +163,8 @@ private AbstractSearchAsyncAction createAction( null, results, request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 2e3ff166a6a53..9876dbdf6f90b 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -32,6 +32,7 @@ package org.opensearch.action.search; import org.apache.lucene.util.BytesRef; +import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -47,6 +48,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -72,6 +74,8 @@ public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { + private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); + public void testFilterShards() throws InterruptedException { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -136,7 +140,8 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -227,7 +232,8 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -317,7 +323,8 @@ public void sendCanMatch( null, new ArraySearchPhaseResults<>(iter.size()), randomIntBetween(1, 32), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -344,7 +351,8 @@ protected void executePhaseOnShard( } } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -428,7 +436,8 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -527,7 +536,8 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index cf838682aa717..521778dcbf171 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -31,6 +31,7 @@ package org.opensearch.action.search; +import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -50,6 +51,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -78,6 +80,8 @@ public class SearchAsyncActionTests extends OpenSearchTestCase { + private SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); + public void testSkipSearchShards() throws InterruptedException { SearchRequest request = new SearchRequest(); request.allowPartialSearchResults(true); @@ -135,7 +139,8 @@ public void testSkipSearchShards() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -253,7 +258,8 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -370,7 +376,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { TestSearchResponse response = new TestSearchResponse(); @@ -492,7 +499,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { TestSearchResponse response = new TestSearchResponse(); @@ -605,7 +613,8 @@ public void testAllowPartialResults() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 4e351e1424cd0..3649ee554c197 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -37,6 +37,7 @@ import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.OriginalIndices; import org.opensearch.cluster.node.DiscoveryNode; @@ -56,6 +57,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; @@ -75,6 +77,8 @@ import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends OpenSearchTestCase { + private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); + public void testBottomFieldSort() throws Exception { testCase(false, false); } @@ -214,7 +218,8 @@ public void sendExecuteQuery( timeProvider, null, task, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index d49d9fd41031c..c7a2ae98624a3 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -10,13 +10,22 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.opensearch.OpenSearchParseException; import org.opensearch.ResourceNotFoundException; import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; +import org.opensearch.action.search.MockSearchPhaseContext; import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchProgressListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -28,8 +37,11 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.breaker.CircuitBreaker; +import org.opensearch.common.breaker.NoopCircuitBreaker; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentType; @@ -39,7 +51,10 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; +import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -67,6 +82,11 @@ public Map> getRequestProcesso public Map> getResponseProcessors(Processor.Parameters parameters) { return Map.of("bar", (factories, tag, description, config) -> null); } + + @Override + public Map> getPhaseInjectorProcessors(Processor.Parameters parameters) { + return Map.of("zoe", (factories, tag, description, config) -> null); + } }; private ThreadPool threadPool; @@ -243,6 +263,41 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } } + private static class FakeSearchPhaseInjectorProcessor extends FakeProcessor implements SearchPhaseInjectorProcessor { + private Consumer querySearchResultConsumer; + + public FakeSearchPhaseInjectorProcessor( + String type, + String tag, + String description, + Consumer querySearchResultConsumer + ) { + super(type, tag, description); + this.querySearchResultConsumer = querySearchResultConsumer; + } + + @Override + public SearchPhaseResults execute( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext + ) { + List resultAtomicArray = searchPhaseResult.getAtomicArray().asList(); + // updating the maxScore + resultAtomicArray.forEach(querySearchResultConsumer); + return searchPhaseResult; + } + + @Override + public SearchPhase.SearchPhaseName getBeforePhase() { + return SearchPhase.SearchPhaseName.QUERY; + } + + @Override + public SearchPhase.SearchPhaseName getAfterPhase() { + return SearchPhase.SearchPhaseName.FETCH; + } + } + private SearchPipelineService createWithProcessors() { Map> requestProcessors = new HashMap<>(); requestProcessors.put("scale_request_size", (processorFactories, tag, description, config) -> { @@ -259,7 +314,15 @@ private SearchPipelineService createWithProcessors() { float score = ((Number) config.remove("score")).floatValue(); return new FakeResponseProcessor("fixed_score", tag, description, rsp -> rsp.getHits().forEach(h -> h.score(score))); }); - return createWithProcessors(requestProcessors, responseProcessors); + + Map> searchPhaseProcessors = new HashMap<>(); + searchPhaseProcessors.put("max_score", (processorFactories, tag, description, config) -> { + final float finalScore = config.containsKey("score") ? ((Number) config.remove("score")).floatValue() : 100f; + final Consumer querySearchResultConsumer = (result) -> result.queryResult().topDocs().maxScore = finalScore; + return new FakeSearchPhaseInjectorProcessor("max_score", tag, description, querySearchResultConsumer); + }); + + return createWithProcessors(requestProcessors, responseProcessors, searchPhaseProcessors); } @Override @@ -270,7 +333,8 @@ protected NamedWriteableRegistry writableRegistry() { private SearchPipelineService createWithProcessors( Map> requestProcessors, - Map> responseProcessors + Map> responseProcessors, + Map> phaseProcessors ) { Client client = mock(Client.class); ThreadPool threadPool = mock(ThreadPool.class); @@ -295,6 +359,14 @@ public Map> getRequestProcesso public Map> getResponseProcessors(Processor.Parameters parameters) { return responseProcessors; } + + @Override + public Map> getPhaseInjectorProcessors( + Processor.Parameters parameters + ) { + return phaseProcessors; + } + }), client, true @@ -313,7 +385,8 @@ public void testUpdatePipelines() { new BytesArray( "{ " + "\"request_processors\" : [ { \"scale_request_size\": { \"scale\" : 2 } } ], " - + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]" + + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]," + + "\"phase_injector_processors\" : [ { \"max_score\" : { \"score\": 100 } } ]" + "}" ), XContentType.JSON @@ -331,6 +404,11 @@ public void testUpdatePipelines() { "scale_request_size", searchPipelineService.getPipelines().get("_id").pipeline.getSearchRequestProcessors().get(0).getType() ); + assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().size()); + assertEquals( + "max_score", + searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().get(0).getType() + ); assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchResponseProcessors().size()); assertEquals( "fixed_score", @@ -368,6 +446,7 @@ public void testPutPipeline() { assertEquals("empty pipeline", pipeline.pipeline.getDescription()); assertEquals(0, pipeline.pipeline.getSearchRequestProcessors().size()); assertEquals(0, pipeline.pipeline.getSearchResponseProcessors().size()); + assertEquals(0, pipeline.pipeline.getSearchPhaseInjectorProcessors().size()); } public void testPutInvalidPipeline() throws IllegalAccessException { @@ -564,6 +643,87 @@ public void testTransformResponse() throws Exception { } } + public void testTransformSearchPhase() { + SearchPipelineService searchPipelineService = createWithProcessors(); + SearchPipelineMetadata metadata = new SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON + ) + ) + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + SearchPhaseController controller = new SearchPhaseController( + writableRegistry(), + s -> InternalAggregationTestCase.emptyReduceContextBuilder() + ); + SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(10); + QueryPhaseResultConsumer searchPhaseResults = new QueryPhaseResultConsumer( + searchPhaseContext.getRequest(), + OpenSearchExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + SearchProgressListener.NOOP, + writableRegistry(), + 2, + exc -> {} + ); + + final QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.setShardIndex(1); + querySearchResult.topDocs(new TopDocsAndMaxScore(new TopDocs(null, new ScoreDoc[1]), 1f), null); + searchPhaseResults.consumeResult(querySearchResult, () -> {}); + + // First try without specifying a pipeline, which should be a no-op. + SearchRequest searchRequest = new SearchRequest(); + PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + SearchPhaseResults notTransformedSearchPhaseResults = pipelinedRequest.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.QUERY.getName(), + SearchPhase.SearchPhaseName.FETCH.getName() + ); + assertSame(searchPhaseResults, notTransformedSearchPhaseResults); + + // Now set the pipeline as p1 + searchRequest = new SearchRequest().pipeline("p1"); + pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + + SearchPhaseResults transformed = pipelinedRequest.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.QUERY.getName(), + SearchPhase.SearchPhaseName.FETCH.getName() + ); + + List resultAtomicArray = transformed.getAtomicArray().asList(); + assertEquals(1, resultAtomicArray.size()); + // updating the maxScore + for (SearchPhaseResult result : resultAtomicArray) { + assertEquals(100f, result.queryResult().topDocs().maxScore, 0); + } + + // Check Processor doesn't run for between other phases + searchRequest = new SearchRequest().pipeline("p1"); + pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + SearchPhaseResults notTransformed = pipelinedRequest.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.DFS_QUERY.getName(), + SearchPhase.SearchPhaseName.QUERY.getName() + ); + + assertSame(searchPhaseResults, notTransformed); + } + public void testGetPipelines() { // assertEquals(0, SearchPipelineService.innerGetPipelines(null, "p1").size()); @@ -581,16 +741,23 @@ public void testGetPipelines() { "p2", new BytesArray("{\"response_processors\" : [ { \"fixed_score\": { \"score\" : 2 } } ] }"), XContentType.JSON + ), + "p3", + new PipelineConfiguration( + "p3", + new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON ) ) ); // Return all when no ids specified List pipelines = SearchPipelineService.innerGetPipelines(metadata); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Get specific pipeline pipelines = SearchPipelineService.innerGetPipelines(metadata, "p1"); @@ -606,17 +773,19 @@ public void testGetPipelines() { // Match all pipelines = SearchPipelineService.innerGetPipelines(metadata, "*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Match prefix pipelines = SearchPipelineService.innerGetPipelines(metadata, "p*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); } public void testValidatePipeline() throws Exception { @@ -624,6 +793,7 @@ public void testValidatePipeline() throws Exception { ProcessorInfo reqProcessor = new ProcessorInfo("scale_request_size"); ProcessorInfo rspProcessor = new ProcessorInfo("fixed_score"); + ProcessorInfo injProcessor = new ProcessorInfo("max_score"); DiscoveryNode n1 = new DiscoveryNode("n1", buildNewFakeTransportAddress(), Version.CURRENT); DiscoveryNode n2 = new DiscoveryNode("n2", buildNewFakeTransportAddress(), Version.CURRENT); PutSearchPipelineRequest putRequest = new PutSearchPipelineRequest( @@ -631,7 +801,8 @@ public void testValidatePipeline() throws Exception { new BytesArray( "{" + "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }]," - + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]" + + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]," + + "\"phase_injector_processors\" : [ { \"max_score\" : { } } ]" + "}" ), XContentType.JSON @@ -729,7 +900,7 @@ public void testExceptionOnPipelineCreation() { "bad_factory", (pf, t, f, c) -> { throw new RuntimeException(); } ); - SearchPipelineService searchPipelineService = createWithProcessors(badFactory, Collections.emptyMap()); + SearchPipelineService searchPipelineService = createWithProcessors(badFactory, Collections.emptyMap(), Collections.emptyMap()); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.REQUEST_PROCESSORS_KEY, List.of(Map.of("bad_factory", Collections.emptyMap()))); @@ -751,7 +922,11 @@ public void testExceptionOnRequestProcessing() { (pf, t, f, c) -> throwingRequestProcessor ); - SearchPipelineService searchPipelineService = createWithProcessors(throwingRequestProcessorFactory, Collections.emptyMap()); + SearchPipelineService searchPipelineService = createWithProcessors( + throwingRequestProcessorFactory, + Collections.emptyMap(), + Collections.emptyMap() + ); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.REQUEST_PROCESSORS_KEY, List.of(Map.of("throwing_request", Collections.emptyMap()))); @@ -772,7 +947,11 @@ public void testExceptionOnResponseProcessing() throws Exception { (pf, t, f, c) -> throwingResponseProcessor ); - SearchPipelineService searchPipelineService = createWithProcessors(Collections.emptyMap(), throwingResponseProcessorFactory); + SearchPipelineService searchPipelineService = createWithProcessors( + Collections.emptyMap(), + throwingResponseProcessorFactory, + Collections.emptyMap() + ); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.RESPONSE_PROCESSORS_KEY, List.of(Map.of("throwing_response", Collections.emptyMap()))); From f68b189480f66f92b5e661f1fda99eb30cab1284 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Wed, 7 Jun 2023 22:56:37 +0000 Subject: [PATCH 02/10] Pass PipelinedRequest to SearchAsyncActions We should resolve a search pipeline once at the start of a search request and then propagate that pipeline through the async actions. When completing a search phase, we will then use that pipeline to inject behavior (if applicable). Signed-off-by: Michael Froh --- .../search/AbstractSearchAsyncAction.java | 46 ++++++------- .../search/CanMatchPreFilterSearchPhase.java | 10 ++- .../SearchDfsQueryThenFetchAsyncAction.java | 14 ++-- .../SearchQueryThenFetchAsyncAction.java | 20 +++--- .../action/search/TransportSearchAction.java | 67 +++++++++---------- .../search/pipeline/PipelinedRequest.java | 17 +++++ .../AbstractSearchAsyncActionTests.java | 8 +-- .../CanMatchPreFilterSearchPhaseTests.java | 35 ++++------ .../action/search/SearchAsyncActionTests.java | 26 +++---- .../SearchQueryThenFetchAsyncActionTests.java | 9 +-- .../search/TransportSearchActionTests.java | 49 +++++++------- 11 files changed, 148 insertions(+), 153 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 269a728394f2d..b92c3a9dd41f3 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -53,12 +53,12 @@ import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.pipeline.PipelinedRequest; -import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.ArrayDeque; @@ -90,7 +90,7 @@ abstract class AbstractSearchAsyncAction exten private final SearchTransportService searchTransportService; private final Executor executor; private final ActionListener listener; - private final SearchRequest request; + private final PipelinedRequest request; /** * Used by subclasses to resolve node ids to DiscoveryNodes. **/ @@ -118,7 +118,6 @@ abstract class AbstractSearchAsyncAction exten private final boolean throttleConcurrentRequests; private final List releasables = new ArrayList<>(); - private final SearchPipelineService searchPipelineService; AbstractSearchAsyncAction( String name, @@ -129,7 +128,7 @@ abstract class AbstractSearchAsyncAction exten Map concreteIndexBoosts, Map> indexRoutings, Executor executor, - SearchRequest request, + PipelinedRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, @@ -137,8 +136,7 @@ abstract class AbstractSearchAsyncAction exten SearchTask task, SearchPhaseResults resultConsumer, int maxConcurrentRequestsPerNode, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super(name); final List toSkipIterators = new ArrayList<>(); @@ -174,7 +172,6 @@ abstract class AbstractSearchAsyncAction exten this.indexRoutings = indexRoutings; this.results = resultConsumer; this.clusters = clusters; - this.searchPipelineService = searchPipelineService; } @Override @@ -200,9 +197,10 @@ public final void start() { if (getNumShards() == 0) { // no search shards to search on, bail with empty response // (it happens with search across _all with no indices around and consistent with broadcast operations) - int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo(); + SearchSourceBuilder source = request.transformedRequest().source(); + int trackTotalHitsUpTo = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo(); // total hits is null in the response if the tracking of total hits is disabled boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED; listener.onResponse( @@ -229,9 +227,10 @@ public final void run() { assert iterator.skip(); skipShard(iterator); } + SearchRequest searchRequest = request.transformedRequest(); if (shardsIts.size() > 0) { - assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.allowPartialSearchResults() == false) { + assert searchRequest.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (searchRequest.allowPartialSearchResults() == false) { final StringBuilder missingShards = new StringBuilder(); // Fail-fast verification of all shards being available for (int index = 0; index < shardsIts.size(); index++) { @@ -376,7 +375,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha logger.debug(() -> new ParameterizedMessage("All shards failed for phase: [{}]", getName()), cause); onPhaseFailure(currentPhase, "all shards failed", cause); } else { - Boolean allowPartialResults = request.allowPartialSearchResults(); + Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && successfulOps.get() != getNumShards()) { // check if there are actual failures in the atomic array since @@ -612,7 +611,7 @@ public final SearchTask getTask() { @Override public final SearchRequest getRequest() { - return request; + return request.transformedRequest(); } protected final SearchResponse buildSearchResponse( @@ -643,19 +642,22 @@ boolean buildPointInTimeFromSearchResults() { @Override public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray queryResults) { ShardSearchFailure[] failures = buildShardFailures(); - Boolean allowPartialResults = request.allowPartialSearchResults(); + Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && failures.length > 0) { raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); - final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null; + final String scrollId = request.transformedRequest().scroll() != null + ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) + : null; final String searchContextId; if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion); } else { - if (request.source() != null && request.source().pointInTimeBuilder() != null) { - searchContextId = request.source().pointInTimeBuilder().getId(); + SearchSourceBuilder source = request.transformedRequest().source(); + if (source != null && source.pointInTimeBuilder() != null) { + searchContextId = source.pointInTimeBuilder().getId(); } else { searchContextId = null; } @@ -677,7 +679,7 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) */ private void raisePhaseFailure(SearchPhaseExecutionException exception) { // we don't release persistent readers (point in time). - if (request.pointInTimeBuilder() == null) { + if (request.transformedRequest().pointInTimeBuilder() == null) { results.getSuccessfulResults().forEach((entry) -> { if (entry.getContextId() != null) { try { @@ -705,9 +707,7 @@ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that // tests pass, we need to do null check on next phase. if (nextPhase != null) { - - final PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(this.getRequest()); - pipelinedRequest.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + request.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); } executeNextPhase(this, nextPhase); } @@ -741,7 +741,7 @@ public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shar final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet()).toArray(new String[0]); ShardSearchRequest shardRequest = new ShardSearchRequest( shardIt.getOriginalIndices(), - request, + request.transformedRequest(), shardIt.shardId(), getNumShards(), filter, diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index bece269902274..44c6aec39016a 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -41,7 +41,7 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortOrder; @@ -84,15 +84,14 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction concreteIndexBoosts, Map> indexRoutings, Executor executor, - SearchRequest request, + PipelinedRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, Function, SearchPhase> phaseFactory, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { // We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests super( @@ -112,8 +111,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, final ClusterState clusterState, final SearchTask task, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super( "dfs", @@ -96,14 +95,13 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - request.getMaxConcurrentShardRequests(), - clusters, - searchPipelineService + request.transformedRequest().getMaxConcurrentShardRequests(), + clusters ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; SearchProgressListener progressListener = task.getProgressListener(); - SearchSourceBuilder sourceBuilder = request.source(); + SearchSourceBuilder sourceBuilder = request.transformedRequest().source(); progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java index 2aaa1d788c5bc..6e72dabfb439b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -39,10 +39,11 @@ import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.transport.Transport; @@ -76,14 +77,13 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super( "query", @@ -101,12 +101,11 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction 0; + SearchSourceBuilder source = request.transformedRequest().source(); + boolean hasFetchPhase = source == null ? true : source.size() > 0; progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index fe7fc2d7ee383..ef83f0450b21a 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -315,7 +315,7 @@ public void executeRequest( @Override public AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, Executor executor, GroupShardsIterator shardsIts, SearchTimeProvider timeProvider, @@ -338,16 +338,15 @@ public AbstractSearchAsyncAction asyncSearchAction( concreteIndexBoosts, indexRoutings, executor, - searchRequest, + pipelinedRequest, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - searchRequest.getMaxConcurrentShardRequests(), - clusters, - searchPipelineService + pipelinedRequest.transformedRequest().getMaxConcurrentShardRequests(), + clusters ) { @Override protected void executePhaseOnShard( @@ -391,11 +390,10 @@ private void executeRequest( relativeStartNanos, System::nanoTime ); - SearchRequest searchRequest; + PipelinedRequest pipelinedRequest; ActionListener listener; try { - PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); - searchRequest = pipelinedRequest.transformedRequest(); + pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); listener = ActionListener.wrap( r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)), originalListener::onFailure @@ -404,6 +402,7 @@ private void executeRequest( originalListener.onFailure(e); return; } + SearchRequest searchRequest = pipelinedRequest.transformedRequest(); ActionListener rewriteListener = ActionListener.wrap(source -> { if (source != searchRequest.source()) { @@ -430,7 +429,7 @@ private void executeRequest( executeLocalSearch( task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, clusterState, listener, @@ -440,7 +439,7 @@ private void executeRequest( } else { if (shouldMinimizeRoundtrips(searchRequest)) { ccsRemoteReduce( - searchRequest, + pipelinedRequest, localIndices, remoteClusterIndices, timeProvider, @@ -497,7 +496,7 @@ private void executeRequest( executeSearch( (SearchTask) task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, remoteShardIterators, clusterNodeLookup, @@ -545,7 +544,7 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) { } static void ccsRemoteReduce( - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, Map remoteIndices, SearchTimeProvider timeProvider, @@ -553,7 +552,7 @@ static void ccsRemoteReduce( RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener listener, - BiConsumer> localSearchConsumer + BiConsumer> localSearchConsumer ) { if (localIndices == null && remoteIndices.size() == 1) { @@ -564,7 +563,7 @@ static void ccsRemoteReduce( boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -613,7 +612,7 @@ public void onFailure(Exception e) { }); } else { SearchResponseMerger searchResponseMerger = createSearchResponseMerger( - searchRequest.source(), + pipelinedRequest.transformedRequest().source(), timeProvider, aggReduceContextBuilder ); @@ -626,7 +625,7 @@ public void onFailure(Exception e) { boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -657,13 +656,14 @@ public void onFailure(Exception e) { listener ); SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), localIndices.indices(), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false ); - localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener); + + localSearchConsumer.accept(pipelinedRequest.replaceRequest(ccsLocalSearchRequest), ccsListener); } } } @@ -779,7 +779,7 @@ SearchResponse createFinalResponse() { private void executeLocalSearch( Task task, SearchTimeProvider timeProvider, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, ClusterState clusterState, ActionListener listener, @@ -789,7 +789,7 @@ private void executeLocalSearch( executeSearch( (SearchTask) task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, Collections.emptyList(), (clusterName, nodeId) -> null, @@ -907,7 +907,7 @@ private Index[] resolveLocalIndices(OriginalIndices localIndices, ClusterState c private void executeSearch( SearchTask task, SearchTimeProvider timeProvider, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, List remoteShardIterators, BiFunction remoteConnections, @@ -929,6 +929,7 @@ private void executeSearch( final Map> indexRoutings; final String[] concreteLocalIndices; + final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); if (searchContext != null) { assert searchRequest.pointInTimeBuilder() != null; aliasFilter = searchContext.aliasFilter(); @@ -1010,7 +1011,7 @@ private void executeSearch( ); searchAsyncActionProvider.asyncSearchAction( task, - searchRequest, + pipelinedRequest, asyncSearchExecutor, shardIterators, timeProvider, @@ -1093,7 +1094,7 @@ static GroupShardsIterator mergeShardsIterators( interface SearchAsyncActionProvider { AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest searchRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1111,7 +1112,7 @@ AbstractSearchAsyncAction asyncSearchAction( private AbstractSearchAsyncAction searchAsyncAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1134,7 +1135,7 @@ private AbstractSearchAsyncAction searchAsyncAction concreteIndexBoosts, indexRoutings, executor, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, @@ -1143,7 +1144,7 @@ private AbstractSearchAsyncAction searchAsyncAction (iter) -> { AbstractSearchAsyncAction action = searchAsyncAction( task, - searchRequest, + pipelinedRequest, executor, iter, timeProvider, @@ -1164,10 +1165,10 @@ public void run() { } }; }, - clusters, - searchPipelineService + clusters ); } else { + final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( executor, circuitBreaker, @@ -1189,14 +1190,13 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, clusterState, task, - clusters, - searchPipelineService + clusters ); break; case QUERY_THEN_FETCH: @@ -1210,14 +1210,13 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, clusterState, task, - clusters, - searchPipelineService + clusters ); break; default: diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index e45b510d7c760..966d6ba5a3e9b 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -49,4 +49,21 @@ public SearchPhaseResults transformSe Pipeline getPipeline() { return pipeline; } + + /** + * Wraps a search request with a no-op pipeline. Useful for testing. + * + * @param searchRequest the original search request + * @return a search request associated with a pipeline that does nothing + */ + public static PipelinedRequest wrapSearchRequest(SearchRequest searchRequest) { + return new PipelinedRequest(Pipeline.NO_OP_PIPELINE, searchRequest); + } + + /** + * Wraps the given search request with this request's pipeline. + */ + public PipelinedRequest replaceRequest(SearchRequest searchRequest) { + return new PipelinedRequest(pipeline, searchRequest); + } } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index f55ce93d019fe..206b8a571bb5b 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -52,7 +52,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -84,7 +84,6 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; - private SearchPipelineService searchPipelineService; @Before @Override @@ -155,7 +154,7 @@ private AbstractSearchAsyncAction createAction( Collections.singletonMap("foo", 2.0f), Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), executor, - request, + PipelinedRequest.wrapSearchRequest(request), listener, new GroupShardsIterator<>(Arrays.asList(shards)), timeProvider, @@ -163,8 +162,7 @@ private AbstractSearchAsyncAction createAction( null, results, request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 9876dbdf6f90b..1a743716025d6 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -32,7 +32,6 @@ package org.opensearch.action.search; import org.apache.lucene.util.BytesRef; -import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -48,7 +47,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -74,8 +73,6 @@ public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { - private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); - public void testFilterShards() throws InterruptedException { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -127,7 +124,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -140,8 +137,7 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -219,7 +215,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -232,8 +228,7 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -301,7 +296,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -315,7 +310,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), executor, - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), responseListener, iter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -323,8 +318,7 @@ public void sendCanMatch( null, new ArraySearchPhaseResults<>(iter.size()), randomIntBetween(1, 32), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -351,8 +345,7 @@ protected void executePhaseOnShard( } } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -423,7 +416,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -436,8 +429,7 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -523,7 +515,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -536,8 +528,7 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 521778dcbf171..53131d884a60a 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -51,6 +51,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -131,7 +132,7 @@ public void testSkipSearchShards() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -139,8 +140,7 @@ public void testSkipSearchShards() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -250,7 +250,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -258,8 +258,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -368,7 +367,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -376,8 +375,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { TestSearchResponse response = new TestSearchResponse(); @@ -491,7 +489,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -499,8 +497,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { TestSearchResponse response = new TestSearchResponse(); @@ -605,7 +602,7 @@ public void testAllowPartialResults() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -613,8 +610,7 @@ public void testAllowPartialResults() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 3649ee554c197..e1bf9244b3a6b 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -37,7 +37,6 @@ import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; -import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.OriginalIndices; import org.opensearch.cluster.node.DiscoveryNode; @@ -57,7 +56,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; @@ -77,7 +76,6 @@ import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends OpenSearchTestCase { - private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); public void testBottomFieldSort() throws Exception { testCase(false, false); @@ -212,14 +210,13 @@ public void sendExecuteQuery( controller, executor, resultConsumer, - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, null, task, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java index 51d9a06c9ac43..96ffb016604f9 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java @@ -74,6 +74,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.MockTransportService; @@ -458,14 +459,14 @@ public void testCCSRemoteReduceMergeFails() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("null_target"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -478,8 +479,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -514,14 +515,14 @@ public void testCCSRemoteReduce() throws Exception { { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -534,8 +535,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -551,14 +552,14 @@ public void testCCSRemoteReduce() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("index_not_found"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -571,8 +572,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -609,14 +610,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -629,8 +630,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -649,14 +650,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -669,8 +670,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -700,14 +701,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -720,8 +721,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } From 15f25be3a79353a0cf28c0dd5d655011ee39364a Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 19 Jun 2023 23:28:49 -0700 Subject: [PATCH 03/10] Renamed SearchPhaseInjectorProcessor to SearchPhaseResultsProcessor and fixed the comments Signed-off-by: Navneet Verma --- .../search/CanMatchPreFilterSearchPhase.java | 2 +- .../opensearch/action/search/SearchPhase.java | 3 +- .../plugins/SearchPipelinePlugin.java | 4 +- .../opensearch/search/pipeline/Pipeline.java | 41 ++++++++++++------- ....java => SearchPhaseResultsProcessor.java} | 4 +- .../pipeline/SearchPipelineService.java | 4 +- .../pipeline/SearchPipelineServiceTests.java | 30 +++++++------- 7 files changed, 50 insertions(+), 38 deletions(-) rename server/src/main/java/org/opensearch/search/pipeline/{SearchPhaseInjectorProcessor.java => SearchPhaseResultsProcessor.java} (93%) diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index 4226a814e096b..cb1dbaa3ff9d5 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -95,7 +95,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction> getResponseProce * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory} * to create the processor from a given pipeline configuration. */ - default Map> getPhaseInjectorProcessors(Processor.Parameters parameters) { + default Map> getPhaseResultsProcessors(Processor.Parameters parameters) { return Collections.emptyMap(); } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 8ae74c9324ab2..09dbc6860b5ff 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -38,7 +38,7 @@ class Pipeline { public static final String REQUEST_PROCESSORS_KEY = "request_processors"; public static final String RESPONSE_PROCESSORS_KEY = "response_processors"; - public static final String PHASE_PROCESSORS_KEY = "phase_injector_processors"; + public static final String PHASE_PROCESSORS_KEY = "phase_results_processors"; private final String id; private final String description; private final Integer version; @@ -48,7 +48,7 @@ class Pipeline { private final List searchRequestProcessors; private final List searchResponseProcessors; private final NamedWriteableRegistry namedWriteableRegistry; - private final List searchPhaseInjectorProcessors; + private final List searchPhaseResultsProcessors; private Pipeline( String id, @@ -56,7 +56,7 @@ private Pipeline( @Nullable Integer version, List requestProcessors, List responseProcessors, - List phaseInjectorProcessors, + List phaseResultsProcessors, NamedWriteableRegistry namedWriteableRegistry ) { this.id = id; @@ -64,7 +64,7 @@ private Pipeline( this.version = version; this.searchRequestProcessors = requestProcessors; this.searchResponseProcessors = responseProcessors; - this.searchPhaseInjectorProcessors = phaseInjectorProcessors; + this.searchPhaseResultsProcessors = phaseResultsProcessors; this.namedWriteableRegistry = namedWriteableRegistry; } @@ -73,7 +73,7 @@ static Pipeline create( Map config, Map> requestProcessorFactories, Map> responseProcessorFactories, - Map> phaseInjectorProcessorFactories, + Map> phaseResultsProcessorFactories, NamedWriteableRegistry namedWriteableRegistry ) throws Exception { String description = ConfigurationUtils.readOptionalStringProperty(null, null, config, DESCRIPTION_KEY); @@ -94,7 +94,10 @@ static Pipeline create( PHASE_PROCESSORS_KEY ); final List responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs); - final List phaseProcessors = readProcessors(phaseInjectorProcessorFactories, phaseProcessorConfigs); + final List phaseResultsProcessors = readProcessors( + phaseResultsProcessorFactories, + phaseProcessorConfigs + ); if (config.isEmpty() == false) { throw new OpenSearchParseException( @@ -104,7 +107,15 @@ static Pipeline create( + Arrays.toString(config.keySet().toArray()) ); } - return new Pipeline(id, description, version, requestProcessors, responseProcessors, phaseProcessors, namedWriteableRegistry); + return new Pipeline( + id, + description, + version, + requestProcessors, + responseProcessors, + phaseResultsProcessors, + namedWriteableRegistry + ); } private static List readProcessors( @@ -132,10 +143,10 @@ private static List readProcessors( List flattenAllProcessors() { List allProcessors = new ArrayList<>( - searchRequestProcessors.size() + searchResponseProcessors.size() + searchPhaseInjectorProcessors.size() + searchRequestProcessors.size() + searchResponseProcessors.size() + searchPhaseResultsProcessors.size() ); allProcessors.addAll(searchRequestProcessors); - allProcessors.addAll(searchPhaseInjectorProcessors); + allProcessors.addAll(searchPhaseResultsProcessors); allProcessors.addAll(searchResponseProcessors); return allProcessors; } @@ -160,8 +171,8 @@ List getSearchResponseProcessors() { return searchResponseProcessors; } - List getSearchPhaseInjectorProcessors() { - return searchPhaseInjectorProcessors; + List getSearchPhaseResultsProcessors() { + return searchPhaseResultsProcessors; } SearchRequest transformRequest(SearchRequest request) throws Exception { @@ -210,10 +221,10 @@ SearchPhaseResults runSearchPhaseTran ) throws SearchPipelineProcessingException { try { - for (SearchPhaseInjectorProcessor searchPhaseInjectorProcessor : searchPhaseInjectorProcessors) { - if (currentPhase.equals(searchPhaseInjectorProcessor.getBeforePhase().getName()) - && nextPhase.equals(searchPhaseInjectorProcessor.getAfterPhase().getName())) { - searchPhaseResult = searchPhaseInjectorProcessor.execute(searchPhaseResult, context); + for (SearchPhaseResultsProcessor searchPhaseResultsProcessor : searchPhaseResultsProcessors) { + if (currentPhase.equals(searchPhaseResultsProcessor.getBeforePhase().getName()) + && nextPhase.equals(searchPhaseResultsProcessor.getAfterPhase().getName())) { + searchPhaseResult = searchPhaseResultsProcessor.process(searchPhaseResult, context); } } return searchPhaseResult; diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java similarity index 93% rename from server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java rename to server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java index 0d4f7950596cf..90f89a63b566e 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -16,8 +16,8 @@ /** * Creates a processor that runs between Phases of the Search. */ -public interface SearchPhaseInjectorProcessor extends Processor { - SearchPhaseResults execute( +public interface SearchPhaseResultsProcessor extends Processor { + SearchPhaseResults process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext ); diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 824253fa9294c..1dbc2d0609cfc 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -73,7 +73,7 @@ public class SearchPipelineService implements ClusterStateApplier, ReportingServ private final Map> requestProcessorFactories; private final Map> responseProcessorFactories; - private final Map> phaseInjectorProcessorFactories; + private final Map> phaseInjectorProcessorFactories; private volatile Map pipelines = Collections.emptyMap(); private final ThreadPool threadPool; private final List> searchPipelineClusterStateListeners = new CopyOnWriteArrayList<>(); @@ -114,7 +114,7 @@ public SearchPipelineService( ); this.requestProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getRequestProcessors(parameters)); this.responseProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getResponseProcessors(parameters)); - this.phaseInjectorProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getPhaseInjectorProcessors(parameters)); + this.phaseInjectorProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getPhaseResultsProcessors(parameters)); putPipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.PUT_SEARCH_PIPELINE_KEY, true); deletePipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.DELETE_SEARCH_PIPELINE_KEY, true); this.isEnabled = isEnabled; diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index c7a2ae98624a3..b5f377a029287 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -84,7 +84,7 @@ public Map> getResponseProces } @Override - public Map> getPhaseInjectorProcessors(Processor.Parameters parameters) { + public Map> getPhaseResultsProcessors(Processor.Parameters parameters) { return Map.of("zoe", (factories, tag, description, config) -> null); } }; @@ -263,10 +263,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } } - private static class FakeSearchPhaseInjectorProcessor extends FakeProcessor implements SearchPhaseInjectorProcessor { + private static class FakeSearchPhaseResultsProcessor extends FakeProcessor implements SearchPhaseResultsProcessor { private Consumer querySearchResultConsumer; - public FakeSearchPhaseInjectorProcessor( + public FakeSearchPhaseResultsProcessor( String type, String tag, String description, @@ -277,7 +277,7 @@ public FakeSearchPhaseInjectorProcessor( } @Override - public SearchPhaseResults execute( + public SearchPhaseResults process( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext ) { @@ -315,11 +315,11 @@ private SearchPipelineService createWithProcessors() { return new FakeResponseProcessor("fixed_score", tag, description, rsp -> rsp.getHits().forEach(h -> h.score(score))); }); - Map> searchPhaseProcessors = new HashMap<>(); + Map> searchPhaseProcessors = new HashMap<>(); searchPhaseProcessors.put("max_score", (processorFactories, tag, description, config) -> { final float finalScore = config.containsKey("score") ? ((Number) config.remove("score")).floatValue() : 100f; final Consumer querySearchResultConsumer = (result) -> result.queryResult().topDocs().maxScore = finalScore; - return new FakeSearchPhaseInjectorProcessor("max_score", tag, description, querySearchResultConsumer); + return new FakeSearchPhaseResultsProcessor("max_score", tag, description, querySearchResultConsumer); }); return createWithProcessors(requestProcessors, responseProcessors, searchPhaseProcessors); @@ -334,7 +334,7 @@ protected NamedWriteableRegistry writableRegistry() { private SearchPipelineService createWithProcessors( Map> requestProcessors, Map> responseProcessors, - Map> phaseProcessors + Map> phaseProcessors ) { Client client = mock(Client.class); ThreadPool threadPool = mock(ThreadPool.class); @@ -361,7 +361,7 @@ public Map> getResponseProces } @Override - public Map> getPhaseInjectorProcessors( + public Map> getPhaseResultsProcessors( Processor.Parameters parameters ) { return phaseProcessors; @@ -386,7 +386,7 @@ public void testUpdatePipelines() { "{ " + "\"request_processors\" : [ { \"scale_request_size\": { \"scale\" : 2 } } ], " + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]," - + "\"phase_injector_processors\" : [ { \"max_score\" : { \"score\": 100 } } ]" + + "\"phase_results_processors\" : [ { \"max_score\" : { \"score\": 100 } } ]" + "}" ), XContentType.JSON @@ -404,10 +404,10 @@ public void testUpdatePipelines() { "scale_request_size", searchPipelineService.getPipelines().get("_id").pipeline.getSearchRequestProcessors().get(0).getType() ); - assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().size()); + assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseResultsProcessors().size()); assertEquals( "max_score", - searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().get(0).getType() + searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseResultsProcessors().get(0).getType() ); assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchResponseProcessors().size()); assertEquals( @@ -446,7 +446,7 @@ public void testPutPipeline() { assertEquals("empty pipeline", pipeline.pipeline.getDescription()); assertEquals(0, pipeline.pipeline.getSearchRequestProcessors().size()); assertEquals(0, pipeline.pipeline.getSearchResponseProcessors().size()); - assertEquals(0, pipeline.pipeline.getSearchPhaseInjectorProcessors().size()); + assertEquals(0, pipeline.pipeline.getSearchPhaseResultsProcessors().size()); } public void testPutInvalidPipeline() throws IllegalAccessException { @@ -650,7 +650,7 @@ public void testTransformSearchPhase() { "p1", new PipelineConfiguration( "p1", - new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + new BytesArray("{\"phase_results_processors\" : [ { \"max_score\" : { } } ]}"), XContentType.JSON ) ) @@ -745,7 +745,7 @@ public void testGetPipelines() { "p3", new PipelineConfiguration( "p3", - new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + new BytesArray("{\"phase_results_processors\" : [ { \"max_score\" : { } } ]}"), XContentType.JSON ) ) @@ -802,7 +802,7 @@ public void testValidatePipeline() throws Exception { "{" + "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }]," + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]," - + "\"phase_injector_processors\" : [ { \"max_score\" : { } } ]" + + "\"phase_results_processors\" : [ { \"max_score\" : { } } ]" + "}" ), XContentType.JSON From 213c7c4a3c0f49d225d30f97784b77891eb2a6c9 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Fri, 23 Jun 2023 19:50:43 +0000 Subject: [PATCH 04/10] Make PipelinedSearchRequest extend SearchRequest Rather than wrapping a SearchRequest in a PipelinedSearchRequest, changes are less intrusive if we say that a PipelinedSearchRequest "is a" SearchRequest. Signed-off-by: Michael Froh --- .../search/AbstractSearchAsyncAction.java | 46 +++++++-------- .../search/CanMatchPreFilterSearchPhase.java | 3 +- .../SearchDfsQueryThenFetchAsyncAction.java | 7 +-- .../SearchQueryThenFetchAsyncAction.java | 13 ++--- .../action/search/TransportSearchAction.java | 56 +++++++++---------- .../search/pipeline/PipelinedRequest.java | 28 +--------- .../AbstractSearchAsyncActionTests.java | 3 +- .../CanMatchPreFilterSearchPhaseTests.java | 13 ++--- .../action/search/SearchAsyncActionTests.java | 15 ++--- .../SearchQueryThenFetchAsyncActionTests.java | 4 +- .../search/TransportSearchActionTests.java | 49 ++++++++-------- .../pipeline/SearchPipelineServiceTests.java | 14 ++--- 12 files changed, 101 insertions(+), 150 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 04814398802ff..6e10c5b7dacc2 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -53,7 +53,6 @@ import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; @@ -90,7 +89,7 @@ abstract class AbstractSearchAsyncAction exten private final SearchTransportService searchTransportService; private final Executor executor; private final ActionListener listener; - private final PipelinedRequest request; + private final SearchRequest request; /** * Used by subclasses to resolve node ids to DiscoveryNodes. **/ @@ -128,7 +127,7 @@ abstract class AbstractSearchAsyncAction exten Map concreteIndexBoosts, Map> indexRoutings, Executor executor, - PipelinedRequest request, + SearchRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, @@ -197,10 +196,9 @@ public final void start() { if (getNumShards() == 0) { // no search shards to search on, bail with empty response // (it happens with search across _all with no indices around and consistent with broadcast operations) - SearchSourceBuilder source = request.transformedRequest().source(); - int trackTotalHitsUpTo = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : source.trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : source.trackTotalHitsUpTo(); + int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : request.source().trackTotalHitsUpTo(); // total hits is null in the response if the tracking of total hits is disabled boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED; listener.onResponse( @@ -227,10 +225,9 @@ public final void run() { assert iterator.skip(); skipShard(iterator); } - SearchRequest searchRequest = request.transformedRequest(); if (shardsIts.size() > 0) { - assert searchRequest.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (searchRequest.allowPartialSearchResults() == false) { + assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (request.allowPartialSearchResults() == false) { final StringBuilder missingShards = new StringBuilder(); // Fail-fast verification of all shards being available for (int index = 0; index < shardsIts.size(); index++) { @@ -375,7 +372,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha logger.debug(() -> new ParameterizedMessage("All shards failed for phase: [{}]", getName()), cause); onPhaseFailure(currentPhase, "all shards failed", cause); } else { - Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); + Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && successfulOps.get() != getNumShards()) { // check if there are actual failures in the atomic array since @@ -611,7 +608,7 @@ public final SearchTask getTask() { @Override public final SearchRequest getRequest() { - return request.transformedRequest(); + return request; } protected final SearchResponse buildSearchResponse( @@ -642,22 +639,19 @@ boolean buildPointInTimeFromSearchResults() { @Override public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray queryResults) { ShardSearchFailure[] failures = buildShardFailures(); - Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); + Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && failures.length > 0) { raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); - final String scrollId = request.transformedRequest().scroll() != null - ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) - : null; + final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null; final String searchContextId; if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion); } else { - SearchSourceBuilder source = request.transformedRequest().source(); - if (source != null && source.pointInTimeBuilder() != null) { - searchContextId = source.pointInTimeBuilder().getId(); + if (request.source() != null && request.source().pointInTimeBuilder() != null) { + searchContextId = request.source().pointInTimeBuilder().getId(); } else { searchContextId = null; } @@ -679,7 +673,7 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) */ private void raisePhaseFailure(SearchPhaseExecutionException exception) { // we don't release persistent readers (point in time). - if (request.transformedRequest().pointInTimeBuilder() == null) { + if (request.pointInTimeBuilder() == null) { results.getSuccessfulResults().forEach((entry) -> { if (entry.getContextId() != null) { try { @@ -704,10 +698,12 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { */ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() final SearchPhase nextPhase = getNextPhase(results, this); - // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that - // tests pass, we need to do null check on next phase. - if (nextPhase != null) { - request.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + if (request instanceof PipelinedRequest && nextPhase != null) { + // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that + // tests pass, we need to do null check on next phase. + if (nextPhase != null) { + ((PipelinedRequest) request).transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + } } executeNextPhase(this, nextPhase); } @@ -741,7 +737,7 @@ public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shar final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet()).toArray(new String[0]); ShardSearchRequest shardRequest = new ShardSearchRequest( shardIt.getOriginalIndices(), - request.transformedRequest(), + request, shardIt.shardId(), getNumShards(), filter, diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index cb1dbaa3ff9d5..c026c72f77f00 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -41,7 +41,6 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortOrder; @@ -84,7 +83,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction concreteIndexBoosts, Map> indexRoutings, Executor executor, - PipelinedRequest request, + SearchRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, diff --git a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 14dc898d5d999..71a986c0e15f7 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -41,7 +41,6 @@ import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.dfs.DfsSearchResult; import org.opensearch.search.internal.AliasFilter; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.transport.Transport; import java.util.List; @@ -71,7 +70,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction final SearchPhaseController searchPhaseController, final Executor executor, final QueryPhaseResultConsumer queryPhaseResultConsumer, - final PipelinedRequest request, + final SearchRequest request, final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, @@ -95,13 +94,13 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - request.transformedRequest().getMaxConcurrentShardRequests(), + request.getMaxConcurrentShardRequests(), clusters ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; SearchProgressListener progressListener = task.getProgressListener(); - SearchSourceBuilder sourceBuilder = request.transformedRequest().source(); + SearchSourceBuilder sourceBuilder = request.source(); progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java index 6e72dabfb439b..1ead14aac6b51 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -39,11 +39,9 @@ import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.transport.Transport; @@ -77,7 +75,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, @@ -101,11 +99,11 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction 0; + boolean hasFetchPhase = request.source() == null ? true : request.source().size() > 0; progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index ef83f0450b21a..df2170cbe2af1 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -315,7 +315,7 @@ public void executeRequest( @Override public AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - PipelinedRequest pipelinedRequest, + SearchRequest searchRequest, Executor executor, GroupShardsIterator shardsIts, SearchTimeProvider timeProvider, @@ -338,14 +338,14 @@ public AbstractSearchAsyncAction asyncSearchAction( concreteIndexBoosts, indexRoutings, executor, - pipelinedRequest, + searchRequest, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - pipelinedRequest.transformedRequest().getMaxConcurrentShardRequests(), + searchRequest.getMaxConcurrentShardRequests(), clusters ) { @Override @@ -390,19 +390,18 @@ private void executeRequest( relativeStartNanos, System::nanoTime ); - PipelinedRequest pipelinedRequest; + PipelinedRequest searchRequest; ActionListener listener; try { - pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); + searchRequest = searchPipelineService.resolvePipeline(originalSearchRequest); listener = ActionListener.wrap( - r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)), + r -> originalListener.onResponse(searchRequest.transformResponse(r)), originalListener::onFailure ); } catch (Exception e) { originalListener.onFailure(e); return; } - SearchRequest searchRequest = pipelinedRequest.transformedRequest(); ActionListener rewriteListener = ActionListener.wrap(source -> { if (source != searchRequest.source()) { @@ -429,7 +428,7 @@ private void executeRequest( executeLocalSearch( task, timeProvider, - pipelinedRequest, + searchRequest, localIndices, clusterState, listener, @@ -439,7 +438,7 @@ private void executeRequest( } else { if (shouldMinimizeRoundtrips(searchRequest)) { ccsRemoteReduce( - pipelinedRequest, + searchRequest, localIndices, remoteClusterIndices, timeProvider, @@ -496,7 +495,7 @@ private void executeRequest( executeSearch( (SearchTask) task, timeProvider, - pipelinedRequest, + searchRequest, localIndices, remoteShardIterators, clusterNodeLookup, @@ -544,7 +543,7 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) { } static void ccsRemoteReduce( - PipelinedRequest pipelinedRequest, + SearchRequest searchRequest, OriginalIndices localIndices, Map remoteIndices, SearchTimeProvider timeProvider, @@ -552,7 +551,7 @@ static void ccsRemoteReduce( RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener listener, - BiConsumer> localSearchConsumer + BiConsumer> localSearchConsumer ) { if (localIndices == null && remoteIndices.size() == 1) { @@ -563,7 +562,7 @@ static void ccsRemoteReduce( boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - pipelinedRequest.transformedRequest(), + searchRequest, indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -612,7 +611,7 @@ public void onFailure(Exception e) { }); } else { SearchResponseMerger searchResponseMerger = createSearchResponseMerger( - pipelinedRequest.transformedRequest().source(), + searchRequest.source(), timeProvider, aggReduceContextBuilder ); @@ -625,7 +624,7 @@ public void onFailure(Exception e) { boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - pipelinedRequest.transformedRequest(), + searchRequest, indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -656,14 +655,13 @@ public void onFailure(Exception e) { listener ); SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest( - pipelinedRequest.transformedRequest(), + searchRequest, localIndices.indices(), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false ); - - localSearchConsumer.accept(pipelinedRequest.replaceRequest(ccsLocalSearchRequest), ccsListener); + localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener); } } } @@ -779,7 +777,7 @@ SearchResponse createFinalResponse() { private void executeLocalSearch( Task task, SearchTimeProvider timeProvider, - PipelinedRequest pipelinedRequest, + SearchRequest searchRequest, OriginalIndices localIndices, ClusterState clusterState, ActionListener listener, @@ -789,7 +787,7 @@ private void executeLocalSearch( executeSearch( (SearchTask) task, timeProvider, - pipelinedRequest, + searchRequest, localIndices, Collections.emptyList(), (clusterName, nodeId) -> null, @@ -907,7 +905,7 @@ private Index[] resolveLocalIndices(OriginalIndices localIndices, ClusterState c private void executeSearch( SearchTask task, SearchTimeProvider timeProvider, - PipelinedRequest pipelinedRequest, + SearchRequest searchRequest, OriginalIndices localIndices, List remoteShardIterators, BiFunction remoteConnections, @@ -929,7 +927,6 @@ private void executeSearch( final Map> indexRoutings; final String[] concreteLocalIndices; - final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); if (searchContext != null) { assert searchRequest.pointInTimeBuilder() != null; aliasFilter = searchContext.aliasFilter(); @@ -1011,7 +1008,7 @@ private void executeSearch( ); searchAsyncActionProvider.asyncSearchAction( task, - pipelinedRequest, + searchRequest, asyncSearchExecutor, shardIterators, timeProvider, @@ -1094,7 +1091,7 @@ static GroupShardsIterator mergeShardsIterators( interface SearchAsyncActionProvider { AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - PipelinedRequest searchRequest, + SearchRequest searchRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1112,7 +1109,7 @@ AbstractSearchAsyncAction asyncSearchAction( private AbstractSearchAsyncAction searchAsyncAction( SearchTask task, - PipelinedRequest pipelinedRequest, + SearchRequest searchRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1135,7 +1132,7 @@ private AbstractSearchAsyncAction searchAsyncAction concreteIndexBoosts, indexRoutings, executor, - pipelinedRequest, + searchRequest, listener, shardIterators, timeProvider, @@ -1144,7 +1141,7 @@ private AbstractSearchAsyncAction searchAsyncAction (iter) -> { AbstractSearchAsyncAction action = searchAsyncAction( task, - pipelinedRequest, + searchRequest, executor, iter, timeProvider, @@ -1168,7 +1165,6 @@ public void run() { clusters ); } else { - final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( executor, circuitBreaker, @@ -1190,7 +1186,7 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - pipelinedRequest, + searchRequest, listener, shardIterators, timeProvider, @@ -1210,7 +1206,7 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - pipelinedRequest, + searchRequest, listener, shardIterators, timeProvider, diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index 966d6ba5a3e9b..eb5fd4b6c4c26 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -19,21 +19,16 @@ * * @opensearch.internal */ -public final class PipelinedRequest { +public final class PipelinedRequest extends SearchRequest { private final Pipeline pipeline; - private final SearchRequest transformedRequest; PipelinedRequest(Pipeline pipeline, SearchRequest transformedRequest) { + super(transformedRequest); this.pipeline = pipeline; - this.transformedRequest = transformedRequest; } public SearchResponse transformResponse(SearchResponse response) { - return pipeline.transformResponse(transformedRequest, response); - } - - public SearchRequest transformedRequest() { - return transformedRequest; + return pipeline.transformResponse(this, response); } public SearchPhaseResults transformSearchPhase( @@ -49,21 +44,4 @@ public SearchPhaseResults transformSe Pipeline getPipeline() { return pipeline; } - - /** - * Wraps a search request with a no-op pipeline. Useful for testing. - * - * @param searchRequest the original search request - * @return a search request associated with a pipeline that does nothing - */ - public static PipelinedRequest wrapSearchRequest(SearchRequest searchRequest) { - return new PipelinedRequest(Pipeline.NO_OP_PIPELINE, searchRequest); - } - - /** - * Wraps the given search request with this request's pipeline. - */ - public PipelinedRequest replaceRequest(SearchRequest searchRequest) { - return new PipelinedRequest(pipeline, searchRequest); - } } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 206b8a571bb5b..ad2657517df9a 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -52,7 +52,6 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -154,7 +153,7 @@ private AbstractSearchAsyncAction createAction( Collections.singletonMap("foo", 2.0f), Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), executor, - PipelinedRequest.wrapSearchRequest(request), + request, listener, new GroupShardsIterator<>(Arrays.asList(shards)), timeProvider, diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 1a743716025d6..2e3ff166a6a53 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -47,7 +47,6 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -124,7 +123,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, @@ -215,7 +214,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, @@ -296,7 +295,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, @@ -310,7 +309,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), executor, - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, responseListener, iter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -416,7 +415,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, @@ -515,7 +514,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 53131d884a60a..cf838682aa717 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -31,7 +31,6 @@ package org.opensearch.action.search; -import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -51,8 +50,6 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; -import org.opensearch.search.pipeline.PipelinedRequest; -import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -81,8 +78,6 @@ public class SearchAsyncActionTests extends OpenSearchTestCase { - private SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); - public void testSkipSearchShards() throws InterruptedException { SearchRequest request = new SearchRequest(); request.allowPartialSearchResults(true); @@ -132,7 +127,7 @@ public void testSkipSearchShards() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - PipelinedRequest.wrapSearchRequest(request), + request, responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -250,7 +245,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - PipelinedRequest.wrapSearchRequest(request), + request, responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -367,7 +362,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - PipelinedRequest.wrapSearchRequest(request), + request, responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -489,7 +484,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - PipelinedRequest.wrapSearchRequest(request), + request, responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -602,7 +597,7 @@ public void testAllowPartialResults() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - PipelinedRequest.wrapSearchRequest(request), + request, responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index e1bf9244b3a6b..4e351e1424cd0 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -56,7 +56,6 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; @@ -76,7 +75,6 @@ import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends OpenSearchTestCase { - public void testBottomFieldSort() throws Exception { testCase(false, false); } @@ -210,7 +208,7 @@ public void sendExecuteQuery( controller, executor, resultConsumer, - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, null, shardsIter, timeProvider, diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java index 96ffb016604f9..51d9a06c9ac43 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java @@ -74,7 +74,6 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.MockTransportService; @@ -459,14 +458,14 @@ public void testCCSRemoteReduceMergeFails() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("null_target"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -479,8 +478,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -515,14 +514,14 @@ public void testCCSRemoteReduce() throws Exception { { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -535,8 +534,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -552,14 +551,14 @@ public void testCCSRemoteReduce() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("index_not_found"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -572,8 +571,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -610,14 +609,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -630,8 +629,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -650,14 +649,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -670,8 +669,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -701,14 +700,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - PipelinedRequest.wrapSearchRequest(searchRequest), + searchRequest, localIndices, remoteIndicesByCluster, timeProvider, @@ -721,8 +720,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index b5f377a029287..d5ccb62e78fac 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -197,13 +197,13 @@ public void testResolveIndexDefaultPipeline() throws Exception { SearchRequest searchRequest = new SearchRequest("my_index").source(SearchSourceBuilder.searchSource().size(5)); PipelinedRequest pipelinedRequest = service.resolvePipeline(searchRequest); assertEquals("p1", pipelinedRequest.getPipeline().getId()); - assertEquals(10, pipelinedRequest.transformedRequest().source().size()); + assertEquals(10, pipelinedRequest.source().size()); // Bypass the default pipeline searchRequest.pipeline("_none"); pipelinedRequest = service.resolvePipeline(searchRequest); assertEquals("_none", pipelinedRequest.getPipeline().getId()); - assertEquals(5, pipelinedRequest.transformedRequest().source().size()); + assertEquals(5, pipelinedRequest.source().size()); } private static abstract class FakeProcessor implements Processor { @@ -584,17 +584,14 @@ public void testTransformRequest() throws Exception { SearchRequest request = new SearchRequest("_index").source(sourceBuilder).pipeline("p1"); PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(request); - SearchRequest transformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(2 * size, transformedRequest.source().size()); + assertEquals(2 * size, pipelinedRequest.source().size()); assertEquals(size, request.source().size()); // This request doesn't specify a pipeline, it doesn't get transformed. request = new SearchRequest("_index").source(sourceBuilder); pipelinedRequest = searchPipelineService.resolvePipeline(request); - SearchRequest notTransformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(size, notTransformedRequest.source().size()); - assertSame(request, notTransformedRequest); + assertEquals(size, pipelinedRequest.source().size()); } public void testTransformResponse() throws Exception { @@ -869,8 +866,7 @@ public void testInlinePipeline() throws Exception { assertEquals(1, pipeline.getSearchResponseProcessors().size()); // Verify that pipeline transforms request - SearchRequest transformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(200, transformedRequest.source().size()); + assertEquals(200, pipelinedRequest.source().size()); int size = 10; SearchHit[] hits = new SearchHit[size]; From b040d3bd01c46989d28414dbc7694c8079283502 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Fri, 23 Jun 2023 20:42:41 +0000 Subject: [PATCH 05/10] Revert code change from merge conflict Signed-off-by: Michael Froh --- .../org/opensearch/search/pipeline/Pipeline.java | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 09dbc6860b5ff..78170be27462d 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -138,17 +138,7 @@ private static List readProcessors( processors.add(processorFactories.get(type).create(processorFactories, tag, description, config)); } } - return processors; - } - - List flattenAllProcessors() { - List allProcessors = new ArrayList<>( - searchRequestProcessors.size() + searchResponseProcessors.size() + searchPhaseResultsProcessors.size() - ); - allProcessors.addAll(searchRequestProcessors); - allProcessors.addAll(searchPhaseResultsProcessors); - allProcessors.addAll(searchResponseProcessors); - return allProcessors; + return Collections.unmodifiableList(processors); } String getId() { From b8e0b6f45d4e0cc239efc02fec7e0532423f949c Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sat, 24 Jun 2023 15:28:04 -0700 Subject: [PATCH 06/10] Updated the changelog with more appropiate wording for the change. Signed-off-by: Navneet Verma --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64e24e6f5d98b..2b161acca5fb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,7 +100,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added -- [SearchPipeline] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline.([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283)) +- [SearchPipeline] Add new search pipeline processor type, SearchPhaseResultsProcessor, that can modify the result of one search phase before starting the next phase.([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283)) - Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642)) - Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452)) - Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653)) From 9205184312a3d70e0f22e6fdb2b48f173cd9bcd4 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Tue, 27 Jun 2023 16:42:25 -0700 Subject: [PATCH 07/10] Fixed Typos in the code Signed-off-by: Navneet Verma --- .../org/opensearch/plugins/SearchPipelinePlugin.java | 2 +- .../java/org/opensearch/search/pipeline/Pipeline.java | 2 +- .../search/pipeline/SearchPhaseResultsProcessor.java | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java index d7923c22cf14a..948b7790d56bd 100644 --- a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java +++ b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java @@ -45,7 +45,7 @@ default Map> getResponseProce } /** - * Returns additional search pipeline search phase injector processor types added by this plugin. + * Returns additional search pipeline search phase results processor types added by this plugin. * * The key of the returned {@link Map} is the unique name for the processor which is specified * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory} diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 78170be27462d..e473e62361b40 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -218,7 +218,7 @@ SearchPhaseResults runSearchPhaseTran } } return searchPhaseResult; - } catch (Exception e) { + } catch (RuntimeException e) { throw new SearchPipelineProcessingException(e); } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java index 90f89a63b566e..4804027455e8c 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -12,11 +12,22 @@ import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; /** * Creates a processor that runs between Phases of the Search. + * @opensearch.api */ public interface SearchPhaseResultsProcessor extends Processor { + + /** + * Processes the {@link SearchPhaseResults} obtained from a {@link SearchPhase} and to be returned back to the + * next {@link SearchPhase}. + * @param searchPhaseResult {@link SearchPhaseResults} + * @param searchPhaseContext {@link SearchContext} + * @return {@link SearchPhaseResults} + * @param {@link SearchPhaseResult} + */ SearchPhaseResults process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext From 958763a1bfb6024b30fe15cf52480f8434a83b45 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 28 Jun 2023 10:42:07 -0700 Subject: [PATCH 08/10] Fixing comments relating to return of SearchPhaseResults from processor Signed-off-by: Navneet Verma --- .../search/AbstractSearchAsyncAction.java | 6 +----- .../opensearch/search/pipeline/Pipeline.java | 5 ++--- .../search/pipeline/PipelinedRequest.java | 4 ++-- .../pipeline/SearchPhaseResultsProcessor.java | 7 +++---- .../pipeline/SearchPipelineServiceTests.java | 18 ++++++++++-------- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 6e10c5b7dacc2..26d78caa7bd4c 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -699,11 +699,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() final SearchPhase nextPhase = getNextPhase(results, this); if (request instanceof PipelinedRequest && nextPhase != null) { - // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that - // tests pass, we need to do null check on next phase. - if (nextPhase != null) { - ((PipelinedRequest) request).transformSearchPhase(results, this, this.getName(), nextPhase.getName()); - } + ((PipelinedRequest) request).transformSearchPhase(results, this, this.getName(), nextPhase.getName()); } executeNextPhase(this, nextPhase); } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index e473e62361b40..9162cf97ceda5 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -203,7 +203,7 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response) null ); - SearchPhaseResults runSearchPhaseTransformer( + void runSearchPhaseTransformer( SearchPhaseResults searchPhaseResult, SearchPhaseContext context, String currentPhase, @@ -214,10 +214,9 @@ SearchPhaseResults runSearchPhaseTran for (SearchPhaseResultsProcessor searchPhaseResultsProcessor : searchPhaseResultsProcessors) { if (currentPhase.equals(searchPhaseResultsProcessor.getBeforePhase().getName()) && nextPhase.equals(searchPhaseResultsProcessor.getAfterPhase().getName())) { - searchPhaseResult = searchPhaseResultsProcessor.process(searchPhaseResult, context); + searchPhaseResultsProcessor.process(searchPhaseResult, context); } } - return searchPhaseResult; } catch (RuntimeException e) { throw new SearchPipelineProcessingException(e); } diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index eb5fd4b6c4c26..8d91876f34f11 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -31,13 +31,13 @@ public SearchResponse transformResponse(SearchResponse response) { return pipeline.transformResponse(this, response); } - public SearchPhaseResults transformSearchPhase( + public void transformSearchPhase( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext, final String currentPhase, final String nextPhase ) { - return pipeline.runSearchPhaseTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); + pipeline.runSearchPhaseTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); } // Visible for testing diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java index 4804027455e8c..783a69114f50f 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -21,14 +21,13 @@ public interface SearchPhaseResultsProcessor extends Processor { /** - * Processes the {@link SearchPhaseResults} obtained from a {@link SearchPhase} and to be returned back to the - * next {@link SearchPhase}. + * Processes the {@link SearchPhaseResults} obtained from a {@link SearchPhase} which will be returned to next + * {@link SearchPhase}. * @param searchPhaseResult {@link SearchPhaseResults} * @param searchPhaseContext {@link SearchContext} - * @return {@link SearchPhaseResults} * @param {@link SearchPhaseResult} */ - SearchPhaseResults process( + void process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext ); diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index d5ccb62e78fac..04bc4ed1d4699 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -43,6 +43,7 @@ import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexSettings; @@ -277,14 +278,13 @@ public FakeSearchPhaseResultsProcessor( } @Override - public SearchPhaseResults process( + public void process( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext ) { List resultAtomicArray = searchPhaseResult.getAtomicArray().asList(); // updating the maxScore resultAtomicArray.forEach(querySearchResultConsumer); - return searchPhaseResult; } @Override @@ -682,26 +682,27 @@ public void testTransformSearchPhase() { // First try without specifying a pipeline, which should be a no-op. SearchRequest searchRequest = new SearchRequest(); PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); - SearchPhaseResults notTransformedSearchPhaseResults = pipelinedRequest.transformSearchPhase( + AtomicArray notTransformedSearchPhaseResults = searchPhaseResults.getAtomicArray(); + pipelinedRequest.transformSearchPhase( searchPhaseResults, searchPhaseContext, SearchPhase.SearchPhaseName.QUERY.getName(), SearchPhase.SearchPhaseName.FETCH.getName() ); - assertSame(searchPhaseResults, notTransformedSearchPhaseResults); + assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResults); // Now set the pipeline as p1 searchRequest = new SearchRequest().pipeline("p1"); pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); - SearchPhaseResults transformed = pipelinedRequest.transformSearchPhase( + pipelinedRequest.transformSearchPhase( searchPhaseResults, searchPhaseContext, SearchPhase.SearchPhaseName.QUERY.getName(), SearchPhase.SearchPhaseName.FETCH.getName() ); - List resultAtomicArray = transformed.getAtomicArray().asList(); + List resultAtomicArray = searchPhaseResults.getAtomicArray().asList(); assertEquals(1, resultAtomicArray.size()); // updating the maxScore for (SearchPhaseResult result : resultAtomicArray) { @@ -711,14 +712,15 @@ public void testTransformSearchPhase() { // Check Processor doesn't run for between other phases searchRequest = new SearchRequest().pipeline("p1"); pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); - SearchPhaseResults notTransformed = pipelinedRequest.transformSearchPhase( + AtomicArray notTransformedSearchPhaseResult = searchPhaseResults.getAtomicArray(); + pipelinedRequest.transformSearchPhase( searchPhaseResults, searchPhaseContext, SearchPhase.SearchPhaseName.DFS_QUERY.getName(), SearchPhase.SearchPhaseName.QUERY.getName() ); - assertSame(searchPhaseResults, notTransformed); + assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResult); } public void testGetPipelines() { From db8b3e055ab4d8a4e9040105a8e1f04aa1507a68 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 28 Jun 2023 12:22:06 -0700 Subject: [PATCH 09/10] Moved SearchPhaseName enum in separate class and fixed comments. Signed-off-by: Navneet Verma --- .../search/AbstractSearchAsyncAction.java | 2 +- .../opensearch/action/search/SearchPhase.java | 24 +------------ .../action/search/SearchPhaseName.java | 31 +++++++++++++++++ .../search/SearchScrollAsyncAction.java | 2 +- ...SearchScrollQueryThenFetchAsyncAction.java | 2 +- .../plugins/SearchPipelinePlugin.java | 2 +- .../opensearch/search/pipeline/Pipeline.java | 2 +- .../search/pipeline/PipelinedRequest.java | 4 +-- .../pipeline/SearchPhaseResultsProcessor.java | 14 ++++---- .../pipeline/SearchPipelineService.java | 5 ++- .../pipeline/SearchPipelineServiceTests.java | 34 ++++++++++--------- 11 files changed, 68 insertions(+), 54 deletions(-) create mode 100644 server/src/main/java/org/opensearch/action/search/SearchPhaseName.java diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 26d78caa7bd4c..5c03a12984aee 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -699,7 +699,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() final SearchPhase nextPhase = getNextPhase(results, this); if (request instanceof PipelinedRequest && nextPhase != null) { - ((PipelinedRequest) request).transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + ((PipelinedRequest) request).transformSearchPhaseResults(results, this, this.getName(), nextPhase.getName()); } executeNextPhase(this, nextPhase); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 43dbbc18a30db..50b0cd8e01c1d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -42,7 +42,7 @@ * * @opensearch.internal */ -public abstract class SearchPhase implements CheckedRunnable { +abstract class SearchPhase implements CheckedRunnable { private final String name; protected SearchPhase(String name) { @@ -64,26 +64,4 @@ public String getName() { public SearchPhaseName getSearchPhaseName() { return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); } - - /** - * Enum for different Search Phases in OpenSearch - * @opensearch.internal - */ - public enum SearchPhaseName { - QUERY("query"), - FETCH("fetch"), - DFS_QUERY("dfs_query"), - EXPAND("expand"), - CAN_MATCH("can_match"); - - private final String name; - - SearchPhaseName(final String name) { - this.name = name; - } - - public String getName() { - return name; - } - } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java new file mode 100644 index 0000000000000..b6f842cf2cce1 --- /dev/null +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +/** + * Enum for different Search Phases in OpenSearch + * @opensearch.internal + */ +public enum SearchPhaseName { + QUERY("query"), + FETCH("fetch"), + DFS_QUERY("dfs_query"), + EXPAND("expand"), + CAN_MATCH("can_match"); + + private final String name; + + SearchPhaseName(final String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java index a0164971f3d19..899c7a3c1dabd 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java @@ -266,7 +266,7 @@ protected SearchPhase sendResponsePhase( SearchPhaseController.ReducedQueryPhase queryPhase, final AtomicArray fetchResults ) { - return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.getName()) { + return new SearchPhase(SearchPhaseName.FETCH.getName()) { @Override public void run() throws IOException { sendResponse(queryPhase, fetchResults); diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 51ffeb2ac83bc..9c0721ef63ea6 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -92,7 +92,7 @@ protected void executeInitialPhase( @Override protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { - return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.getName()) { + return new SearchPhase(SearchPhaseName.FETCH.getName()) { @Override public void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase( diff --git a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java index 948b7790d56bd..3d76bab93a60c 100644 --- a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java +++ b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java @@ -51,7 +51,7 @@ default Map> getResponseProce * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory} * to create the processor from a given pipeline configuration. */ - default Map> getPhaseResultsProcessors(Processor.Parameters parameters) { + default Map> getSearchPhaseResultsProcessors(Processor.Parameters parameters) { return Collections.emptyMap(); } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 9162cf97ceda5..12cc8f14338c5 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -203,7 +203,7 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response) null ); - void runSearchPhaseTransformer( + void runSearchPhaseResultsTransformer( SearchPhaseResults searchPhaseResult, SearchPhaseContext context, String currentPhase, diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index 8d91876f34f11..5a7539808c127 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -31,13 +31,13 @@ public SearchResponse transformResponse(SearchResponse response) { return pipeline.transformResponse(this, response); } - public void transformSearchPhase( + public void transformSearchPhaseResults( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext, final String currentPhase, final String nextPhase ) { - pipeline.runSearchPhaseTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); + pipeline.runSearchPhaseResultsTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); } // Visible for testing diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java index 783a69114f50f..772dc8758bace 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -8,8 +8,8 @@ package org.opensearch.search.pipeline; -import org.opensearch.action.search.SearchPhase; import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.internal.SearchContext; @@ -21,8 +21,8 @@ public interface SearchPhaseResultsProcessor extends Processor { /** - * Processes the {@link SearchPhaseResults} obtained from a {@link SearchPhase} which will be returned to next - * {@link SearchPhase}. + * Processes the {@link SearchPhaseResults} obtained from a SearchPhase which will be returned to next + * SearchPhase. * @param searchPhaseResult {@link SearchPhaseResults} * @param searchPhaseContext {@link SearchContext} * @param {@link SearchPhaseResult} @@ -34,14 +34,14 @@ void process( /** * The phase which should have run before, this processor can start executing. - * @return {@link SearchPhase.SearchPhaseName} + * @return {@link SearchPhaseName} */ - SearchPhase.SearchPhaseName getBeforePhase(); + SearchPhaseName getBeforePhase(); /** * The phase which should run after, this processor execution. - * @return {@link SearchPhase.SearchPhaseName} + * @return {@link SearchPhaseName} */ - SearchPhase.SearchPhaseName getAfterPhase(); + SearchPhaseName getAfterPhase(); } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 1dbc2d0609cfc..29bab3aac6910 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -114,7 +114,10 @@ public SearchPipelineService( ); this.requestProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getRequestProcessors(parameters)); this.responseProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getResponseProcessors(parameters)); - this.phaseInjectorProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getPhaseResultsProcessors(parameters)); + this.phaseInjectorProcessorFactories = processorFactories( + searchPipelinePlugins, + p -> p.getSearchPhaseResultsProcessors(parameters) + ); putPipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.PUT_SEARCH_PIPELINE_KEY, true); deletePipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.DELETE_SEARCH_PIPELINE_KEY, true); this.isEnabled = isEnabled; diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 04bc4ed1d4699..6685245748fef 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -21,9 +21,9 @@ import org.opensearch.action.search.MockSearchPhaseContext; import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.search.QueryPhaseResultConsumer; -import org.opensearch.action.search.SearchPhase; import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseName; import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchProgressListener; import org.opensearch.action.search.SearchRequest; @@ -85,7 +85,9 @@ public Map> getResponseProces } @Override - public Map> getPhaseResultsProcessors(Processor.Parameters parameters) { + public Map> getSearchPhaseResultsProcessors( + Processor.Parameters parameters + ) { return Map.of("zoe", (factories, tag, description, config) -> null); } }; @@ -288,13 +290,13 @@ public void process( } @Override - public SearchPhase.SearchPhaseName getBeforePhase() { - return SearchPhase.SearchPhaseName.QUERY; + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; } @Override - public SearchPhase.SearchPhaseName getAfterPhase() { - return SearchPhase.SearchPhaseName.FETCH; + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; } } @@ -361,7 +363,7 @@ public Map> getResponseProces } @Override - public Map> getPhaseResultsProcessors( + public Map> getSearchPhaseResultsProcessors( Processor.Parameters parameters ) { return phaseProcessors; @@ -683,11 +685,11 @@ public void testTransformSearchPhase() { SearchRequest searchRequest = new SearchRequest(); PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); AtomicArray notTransformedSearchPhaseResults = searchPhaseResults.getAtomicArray(); - pipelinedRequest.transformSearchPhase( + pipelinedRequest.transformSearchPhaseResults( searchPhaseResults, searchPhaseContext, - SearchPhase.SearchPhaseName.QUERY.getName(), - SearchPhase.SearchPhaseName.FETCH.getName() + SearchPhaseName.QUERY.getName(), + SearchPhaseName.FETCH.getName() ); assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResults); @@ -695,11 +697,11 @@ public void testTransformSearchPhase() { searchRequest = new SearchRequest().pipeline("p1"); pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); - pipelinedRequest.transformSearchPhase( + pipelinedRequest.transformSearchPhaseResults( searchPhaseResults, searchPhaseContext, - SearchPhase.SearchPhaseName.QUERY.getName(), - SearchPhase.SearchPhaseName.FETCH.getName() + SearchPhaseName.QUERY.getName(), + SearchPhaseName.FETCH.getName() ); List resultAtomicArray = searchPhaseResults.getAtomicArray().asList(); @@ -713,11 +715,11 @@ public void testTransformSearchPhase() { searchRequest = new SearchRequest().pipeline("p1"); pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); AtomicArray notTransformedSearchPhaseResult = searchPhaseResults.getAtomicArray(); - pipelinedRequest.transformSearchPhase( + pipelinedRequest.transformSearchPhaseResults( searchPhaseResults, searchPhaseContext, - SearchPhase.SearchPhaseName.DFS_QUERY.getName(), - SearchPhase.SearchPhaseName.QUERY.getName() + SearchPhaseName.DFS_QUERY.getName(), + SearchPhaseName.QUERY.getName() ); assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResult); From b35858d3f95999800c14d538f35f8121b70385a1 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Thu, 29 Jun 2023 00:17:38 +0000 Subject: [PATCH 10/10] Resolve remaining merge conflict Signed-off-by: Michael Froh --- .../search/pipeline/PipelineWithMetrics.java | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java b/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java index 662473f190006..612e979e56070 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java @@ -43,12 +43,22 @@ class PipelineWithMetrics extends Pipeline { Integer version, List requestProcessors, List responseProcessors, + List phaseResultsProcessors, NamedWriteableRegistry namedWriteableRegistry, OperationMetrics totalRequestMetrics, OperationMetrics totalResponseMetrics, LongSupplier relativeTimeSupplier ) { - super(id, description, version, requestProcessors, responseProcessors, namedWriteableRegistry, relativeTimeSupplier); + super( + id, + description, + version, + requestProcessors, + responseProcessors, + phaseResultsProcessors, + namedWriteableRegistry, + relativeTimeSupplier + ); this.totalRequestMetrics = totalRequestMetrics; this.totalResponseMetrics = totalResponseMetrics; for (Processor requestProcessor : getSearchRequestProcessors()) { @@ -64,6 +74,7 @@ static PipelineWithMetrics create( Map config, Map> requestProcessorFactories, Map> responseProcessorFactories, + Map> phaseResultsProcessorFactories, NamedWriteableRegistry namedWriteableRegistry, OperationMetrics totalRequestProcessingMetrics, OperationMetrics totalResponseProcessingMetrics @@ -79,6 +90,16 @@ static PipelineWithMetrics create( RESPONSE_PROCESSORS_KEY ); List responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs); + List> phaseResultsProcessorConfigs = ConfigurationUtils.readOptionalList( + null, + null, + config, + PHASE_PROCESSORS_KEY + ); + List phaseResultsProcessors = readProcessors( + phaseResultsProcessorFactories, + phaseResultsProcessorConfigs + ); if (config.isEmpty() == false) { throw new OpenSearchParseException( "pipeline [" @@ -93,6 +114,7 @@ static PipelineWithMetrics create( version, requestProcessors, responseProcessors, + phaseResultsProcessors, namedWriteableRegistry, totalRequestProcessingMetrics, totalResponseProcessingMetrics,