Skip to content

Commit

Permalink
Adding the SearchPhaseResultsProcessor interface in Search Pipeline (#…
Browse files Browse the repository at this point in the history
…7283) (#8512)

Add new search pipeline processor type, SearchPhaseResultsProcessor, that can modify the result of one search phase before starting the next phase.

Along with this, add code to resolve the search pipeline only once, and add a new SearchRequest type, PipelinedRequest.

Backport of PR: #7283

---------

Signed-off-by: Navneet Verma <navneev@amazon.com>
Co-authored-by: Michael Froh <froh@amazon.com>
Co-authored-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
3 people authored Jul 7, 2023
1 parent 94b1d48 commit 56a24bb
Show file tree
Hide file tree
Showing 22 changed files with 402 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x]
### Added
- [SearchPipeline] Add new search pipeline processor type, SearchPhaseResultsProcessor, that can modify the result of one search phase before starting the next phase ([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283))
- Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642))
- Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452))
- Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ teardown:
{
"script" : {
"lang" : "painless",
"source" : "ctx._source['size'] += 10; ctx._source['from'] -= 1; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];"
"source" : "ctx._source['size'] += 10; ctx._source['from'] = ctx._source['from'] <= 0 ? ctx._source['from'] : ctx._source['from'] - 1 ; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];"
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.transport.Transport;

import java.util.ArrayDeque;
Expand Down Expand Up @@ -696,7 +697,11 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
*/
final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
executeNextPhase(this, getNextPhase(results, this));
final SearchPhase nextPhase = getNextPhase(results, this);
if (request instanceof PipelinedRequest && nextPhase != null) {
((PipelinedRequest) request).transformSearchPhaseResults(results, this, this.getName(), nextPhase.getName());
}
executeNextPhase(this, nextPhase);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) {
}

@Override
AtomicArray<Result> getAtomicArray() {
public AtomicArray<Result> getAtomicArray() {
return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
) {
// We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
super(
"can_match",
SearchPhaseName.CAN_MATCH.getName(),
logger,
searchTransportService,
nodeIdToConnection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class DfsQueryPhase extends SearchPhase {
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context
) {
super("dfs_query");
super(SearchPhaseName.DFS_QUERY.getName());
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase {
private final AtomicArray<SearchPhaseResult> queryResults;

ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray<SearchPhaseResult> queryResults) {
super("expand");
super(SearchPhaseName.EXPAND.getName());
this.context = context;
this.searchResponse = searchResponse;
this.queryResults = queryResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase {
SearchPhaseContext context,
BiFunction<InternalSearchResponse, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
super("fetch");
super(SearchPhaseName.FETCH.getName());
if (context.getNumShards() != resultConsumer.getNumShards()) {
throw new IllegalStateException(
"number of shards must match the length of the query results but doesn't:"
Expand Down
10 changes: 10 additions & 0 deletions server/src/main/java/org/opensearch/action/search/SearchPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.common.CheckedRunnable;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

/**
Expand All @@ -54,4 +55,13 @@ protected SearchPhase(String name) {
public String getName() {
return name;
}

/**
 * Resolves this phase's name to the matching {@link SearchPhaseName} constant. The name is
 * upper-cased with {@link Locale#ROOT} and then looked up through {@code SearchPhaseName.valueOf}.
 *
 * @return the {@link SearchPhaseName} whose constant name equals the upper-cased phase name
 * @throws IllegalArgumentException if no {@link SearchPhaseName} constant has that name
 */
public SearchPhaseName getSearchPhaseName() {
    final String constantName = name.toUpperCase(Locale.ROOT);
    return SearchPhaseName.valueOf(constantName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
*
* @opensearch.internal
*/
interface SearchPhaseContext extends Executor {
public interface SearchPhaseContext extends Executor {
// TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.action.search;

/**
 * Names of the different search phases in OpenSearch. Every constant carries the
 * lower-case string form of its phase name, which is exposed through {@link #getName()}.
 *
 * @opensearch.internal
 */
public enum SearchPhaseName {
    QUERY("query"),
    FETCH("fetch"),
    DFS_QUERY("dfs_query"),
    EXPAND("expand"),
    CAN_MATCH("can_match");

    /** String form of the phase name, for example {@code "dfs_query"}. */
    private final String phaseName;

    SearchPhaseName(final String phaseName) {
        this.phaseName = phaseName;
    }

    /**
     * Returns the string form of this phase name.
     *
     * @return the string supplied to this constant's constructor
     */
    public String getName() {
        return this.phaseName;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
*
* @opensearch.internal
*/
abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
public abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
private final int numShards;

SearchPhaseResults(int numShards) {
Expand Down Expand Up @@ -75,7 +75,13 @@ final int getNumShards() {

void consumeShardFailure(int shardIndex) {}

AtomicArray<Result> getAtomicArray() {
/**
* Returns an {@link AtomicArray} of {@link Result} holding the results of this search phase
* for shards. {@link Result} is a subtype of {@link SearchPhaseResult}.
*
* This base implementation does not hold such an array and always throws; subclasses may
* override it to expose their results.
*
* @return an {@link AtomicArray} of {@link Result}
* @throws UnsupportedOperationException always, in this base implementation
*/
public AtomicArray<Result> getAtomicArray() {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ protected SearchPhase sendResponsePhase(
SearchPhaseController.ReducedQueryPhase queryPhase,
final AtomicArray<? extends SearchPhaseResult> fetchResults
) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhaseName.FETCH.getName()) {
@Override
public void run() throws IOException {
sendResponse(queryPhase, fetchResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ protected void executeInitialPhase(

@Override
protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhaseName.FETCH.getName()) {
@Override
public void run() {
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,12 @@ private void executeRequest(
relativeStartNanos,
System::nanoTime
);
SearchRequest searchRequest;
PipelinedRequest searchRequest;
ActionListener<SearchResponse> listener;
try {
PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest);
searchRequest = pipelinedRequest.transformedRequest();
searchRequest = searchPipelineService.resolvePipeline(originalSearchRequest);
listener = ActionListener.wrap(
r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)),
r -> originalListener.onResponse(searchRequest.transformResponse(r)),
originalListener::onFailure
);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.plugins;

import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

Expand Down Expand Up @@ -42,4 +43,15 @@ default Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcess
default Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
return Collections.emptyMap();
}

/**
* Returns additional search pipeline search phase results processor types added by this plugin.
*
* The key of the returned {@link Map} is the unique name for the processor which is specified
* in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory}
* to create the processor from a given pipeline configuration.
*
* @param parameters processor parameters passed to the plugin; not used by this default implementation
* @return a map from processor name to processor factory; this default implementation returns an empty map
*/
default Map<String, Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(Processor.Parameters parameters) {
return Collections.emptyMap();
}
}
33 changes: 31 additions & 2 deletions server/src/main/java/org/opensearch/search/pipeline/Pipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@

package org.opensearch.search.pipeline;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.Nullable;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.search.SearchPhaseResult;

import java.util.Collections;
import java.util.List;
Expand All @@ -28,6 +31,7 @@ class Pipeline {

public static final String REQUEST_PROCESSORS_KEY = "request_processors";
public static final String RESPONSE_PROCESSORS_KEY = "response_processors";
public static final String PHASE_PROCESSORS_KEY = "phase_results_processors";
private final String id;
private final String description;
private final Integer version;
Expand All @@ -36,7 +40,7 @@ class Pipeline {
// Then these can be CompoundProcessors instead of lists.
private final List<SearchRequestProcessor> searchRequestProcessors;
private final List<SearchResponseProcessor> searchResponseProcessors;

private final List<SearchPhaseResultsProcessor> searchPhaseResultsProcessors;
private final NamedWriteableRegistry namedWriteableRegistry;
private final LongSupplier relativeTimeSupplier;

Expand All @@ -46,6 +50,7 @@ class Pipeline {
@Nullable Integer version,
List<SearchRequestProcessor> requestProcessors,
List<SearchResponseProcessor> responseProcessors,
List<SearchPhaseResultsProcessor> phaseResultsProcessors,
NamedWriteableRegistry namedWriteableRegistry,
LongSupplier relativeTimeSupplier
) {
Expand All @@ -54,6 +59,7 @@ class Pipeline {
this.version = version;
this.searchRequestProcessors = Collections.unmodifiableList(requestProcessors);
this.searchResponseProcessors = Collections.unmodifiableList(responseProcessors);
this.searchPhaseResultsProcessors = Collections.unmodifiableList(phaseResultsProcessors);
this.namedWriteableRegistry = namedWriteableRegistry;
this.relativeTimeSupplier = relativeTimeSupplier;
}
Expand All @@ -78,6 +84,10 @@ List<SearchResponseProcessor> getSearchResponseProcessors() {
return searchResponseProcessors;
}

/**
* Returns the search phase results processors held by this pipeline.
*
* @return the list of {@link SearchPhaseResultsProcessor}s stored in this pipeline
*/
List<SearchPhaseResultsProcessor> getSearchPhaseResultsProcessors() {
return searchPhaseResultsProcessors;
}

protected void beforeTransformRequest() {}

protected void afterTransformRequest(long timeInNanos) {}
Expand Down Expand Up @@ -168,14 +178,33 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response)
return response;
}

/**
 * Runs the {@link SearchPhaseResultsProcessor}s of this pipeline that match the given pair of
 * phase names against the supplied phase results.
 * <p>
 * A processor is invoked only when {@code currentPhase} equals the name returned by its
 * {@code getBeforePhase()} and {@code nextPhase} equals the name returned by its
 * {@code getAfterPhase()}. Processors are visited in list order.
 *
 * @param searchPhaseResult phase results handed to every matching processor
 * @param context           search phase context handed to every matching processor
 * @param currentPhase      name of the current search phase
 * @param nextPhase         name of the next search phase
 * @param <Result>          concrete type of the search phase result
 * @throws SearchPipelineProcessingException wrapping any {@link RuntimeException} raised while
 *                                           matching or running a processor
 */
<Result extends SearchPhaseResult> void runSearchPhaseResultsTransformer(
    SearchPhaseResults<Result> searchPhaseResult,
    SearchPhaseContext context,
    String currentPhase,
    String nextPhase
) throws SearchPipelineProcessingException {
    try {
        for (final SearchPhaseResultsProcessor processor : searchPhaseResultsProcessors) {
            // Check the current phase first so getAfterPhase() is only consulted on a match.
            final boolean matchesCurrent = currentPhase.equals(processor.getBeforePhase().getName());
            if (matchesCurrent == false) {
                continue;
            }
            final boolean matchesNext = nextPhase.equals(processor.getAfterPhase().getName());
            if (matchesNext == false) {
                continue;
            }
            processor.process(searchPhaseResult, context);
        }
    } catch (RuntimeException e) {
        throw new SearchPipelineProcessingException(e);
    }
}

/**
* Shared pipeline instance that has no request, response, or phase results processors.
* It is built with {@link SearchPipelineService#NOOP_PIPELINE_ID} as its id, version {@code 0},
* a {@code null} {@link NamedWriteableRegistry}, and a relative-time supplier that always returns {@code 0}.
*/
static final Pipeline NO_OP_PIPELINE = new Pipeline(
SearchPipelineService.NOOP_PIPELINE_ID,
"Pipeline that does not transform anything",
0,
Collections.emptyList(),
Collections.emptyList(),
Collections.emptyList(),
null,
() -> 0L
);

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,22 @@ class PipelineWithMetrics extends Pipeline {
Integer version,
List<SearchRequestProcessor> requestProcessors,
List<SearchResponseProcessor> responseProcessors,
List<SearchPhaseResultsProcessor> phaseResultsProcessors,
NamedWriteableRegistry namedWriteableRegistry,
OperationMetrics totalRequestMetrics,
OperationMetrics totalResponseMetrics,
LongSupplier relativeTimeSupplier
) {
super(id, description, version, requestProcessors, responseProcessors, namedWriteableRegistry, relativeTimeSupplier);
super(
id,
description,
version,
requestProcessors,
responseProcessors,
phaseResultsProcessors,
namedWriteableRegistry,
relativeTimeSupplier
);
this.totalRequestMetrics = totalRequestMetrics;
this.totalResponseMetrics = totalResponseMetrics;
for (Processor requestProcessor : getSearchRequestProcessors()) {
Expand All @@ -64,6 +74,7 @@ static PipelineWithMetrics create(
Map<String, Object> config,
Map<String, Processor.Factory<SearchRequestProcessor>> requestProcessorFactories,
Map<String, Processor.Factory<SearchResponseProcessor>> responseProcessorFactories,
Map<String, Processor.Factory<SearchPhaseResultsProcessor>> phaseResultsProcessorFactories,
NamedWriteableRegistry namedWriteableRegistry,
OperationMetrics totalRequestProcessingMetrics,
OperationMetrics totalResponseProcessingMetrics
Expand All @@ -79,6 +90,16 @@ static PipelineWithMetrics create(
RESPONSE_PROCESSORS_KEY
);
List<SearchResponseProcessor> responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs);
List<Map<String, Object>> phaseResultsProcessorConfigs = ConfigurationUtils.readOptionalList(
null,
null,
config,
PHASE_PROCESSORS_KEY
);
List<SearchPhaseResultsProcessor> phaseResultsProcessors = readProcessors(
phaseResultsProcessorFactories,
phaseResultsProcessorConfigs
);
if (config.isEmpty() == false) {
throw new OpenSearchParseException(
"pipeline ["
Expand All @@ -93,6 +114,7 @@ static PipelineWithMetrics create(
version,
requestProcessors,
responseProcessors,
phaseResultsProcessors,
namedWriteableRegistry,
totalRequestProcessingMetrics,
totalResponseProcessingMetrics,
Expand Down
Loading

0 comments on commit 56a24bb

Please sign in to comment.