Skip to content

Commit

Permalink
update get api
Browse files Browse the repository at this point in the history
  • Loading branch information
ruai0511 committed Feb 11, 2025
1 parent 3865950 commit f954f61
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,33 @@
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest;
import org.opensearch.wlm.Rule;
import org.opensearch.wlm.Rule.Builder;
import org.opensearch.wlm.Rule.RuleAttribute;
import org.opensearch.common.UUIDs;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.Map;

/**
* A request for get Rule
* @opensearch.experimental
*/
public class GetRuleRequest extends ClusterManagerNodeRequest<GetRuleRequest> {
private final String _id;
private final Map<RuleAttribute, Set<String>> attributeFilters;

/**
* Constructor for GetRuleRequest
* @param _id - Rule _id that we want to get
* @param attributeFilters - Attributes that we want to filter on
*/
public GetRuleRequest(String _id) {
public GetRuleRequest(String _id, Map<RuleAttribute, Set<String>> attributeFilters) {
this._id = _id;
this.attributeFilters = attributeFilters;
}

/**
Expand All @@ -42,6 +48,7 @@ public GetRuleRequest(String _id) {
public GetRuleRequest(StreamInput in) throws IOException {
super(in);
_id = in.readOptionalString();
attributeFilters = in.readMap((i) -> RuleAttribute.fromName(i.readString()), i -> new HashSet<>(i.readStringList()));
}

@Override
Expand All @@ -53,6 +60,7 @@ public ActionRequestValidationException validate() {
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(_id);
out.writeMap(attributeFilters, RuleAttribute::writeTo, StreamOutput::writeStringCollection);
}

/**
Expand All @@ -61,4 +69,11 @@ public void writeTo(StreamOutput out) throws IOException {
public String get_id() {
return _id;
}

/**
* attributeFilters getter
*/
public Map<RuleAttribute, Set<String>> getAttributeFilters() {
return attributeFilters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.Map;

import static org.opensearch.wlm.Rule._ID_STRING;

/**
* Response for the get API for Rule
* @opensearch.experimental
Expand Down Expand Up @@ -61,7 +63,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.startArray("rules");
for (Map.Entry<String, Rule> entry : rules.entrySet()) {
entry.getValue().toXContent(builder, new MapParams(Map.of("_id", entry.getKey())));
entry.getValue().toXContent(builder, new MapParams(Map.of(_ID_STRING, entry.getKey())));
}
builder.endArray();
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ public TransportGetRuleAction(

@Override
protected void doExecute(Task task, GetRuleRequest request, ActionListener<GetRuleResponse> listener) {
rulePersistenceService.getRule(request.get_id(), listener);
rulePersistenceService.getRule(request.get_id(), request.getAttributeFilters(), listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@
import org.opensearch.plugin.wlm.rule.action.*;
import org.opensearch.rest.*;
import org.opensearch.rest.action.RestResponseListener;
import org.opensearch.telemetry.tracing.AttributeNames;
import org.opensearch.wlm.Rule;
import org.opensearch.wlm.Rule.RuleAttribute;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Arrays;
import java.util.HashSet;

import static org.opensearch.rest.RestRequest.Method.*;
import static org.opensearch.rest.RestRequest.Method.GET;
import static org.opensearch.wlm.Rule._ID_STRING;

/**
* Rest action to get a Rule
Expand Down Expand Up @@ -49,7 +58,15 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
final GetRuleRequest getRuleRequest = new GetRuleRequest(request.param("_id"));
Map<RuleAttribute, Set<String>> attributeFilters = new HashMap<>();
for (String attributeName : request.params().keySet()) {
if (attributeName.equals(_ID_STRING)) {
continue;
}
String[] valuesArray = request.param(attributeName).split(",");
attributeFilters.put(RuleAttribute.fromName(attributeName), new HashSet<>(Arrays.asList(valuesArray)));
}
final GetRuleRequest getRuleRequest = new GetRuleRequest(request.param(_ID_STRING), attributeFilters);
return channel -> client.execute(GetRuleAction.INSTANCE, getRuleRequest, getRuleResponse(channel));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.wlm.Rule;
import org.opensearch.wlm.Rule.RuleAttribute;
import org.opensearch.wlm.Rule.Builder;
import org.opensearch.common.inject.Inject;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -27,7 +34,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.Objects;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -39,6 +46,7 @@
public class RulePersistenceService {
public static final String RULE_INDEX = ".rule";
private final Client client;
private final ClusterService clusterService;
private static final Logger logger = LogManager.getLogger(RulePersistenceService.class);

/**
Expand All @@ -47,13 +55,29 @@ public class RulePersistenceService {
*/
@Inject
public RulePersistenceService(
final ClusterService clusterService,
final Client client
) {
this.clusterService = clusterService;
this.client = client;
}

public void createRule(Rule rule, ActionListener<CreateRuleResponse> listener) {
try {
final Map<String, Object> indexSettings = Map.of("index.number_of_shards", 1, "index.auto_expand_replicas", "0-all");
createIfAbsent(RULE_INDEX, indexSettings, new ActionListener<>() {
@Override
public void onResponse(Boolean indexCreated) {
persistRule(rule, listener);
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
}

public void persistRule(Rule rule, ActionListener<CreateRuleResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
IndexRequest indexRequest = new IndexRequest(RULE_INDEX)
.source(rule.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));

Expand All @@ -68,16 +92,42 @@ public void createRule(Rule rule, ActionListener<CreateRuleResponse> listener) {
}
));
} catch (IOException e) {
logger.error("Error saving rule to index: " + RULE_INDEX, e);
logger.error("Error saving rule to index: {}", RULE_INDEX, e);
listener.onFailure(new RuntimeException("Failed to save rule to index."));
}
}

public void getRule(String id, ActionListener<GetRuleResponse> listener) {
private void createIfAbsent(String indexName, Map<String, Object> indexSettings, ActionListener<Boolean> listener) {
if (clusterService.state().metadata().hasIndex(indexName)) {
listener.onResponse(true);
return;
}
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(indexSettings);
client.admin().indices().create(createIndexRequest, new ActionListener<>() {
@Override
public void onResponse(CreateIndexResponse response) {
logger.info("Index {} created?: {}", indexName, response.isAcknowledged());
listener.onResponse(response.isAcknowledged());
}

@Override
public void onFailure(Exception e) {
if (e instanceof ResourceAlreadyExistsException) {
logger.info("Index {} already exists", indexName);
listener.onResponse(true);
} else {
logger.error("Failed to create index {}: {}", indexName, e.getMessage());
listener.onFailure(e);
}
}
});
}

public void getRule(String id, Map<RuleAttribute, Set<String>> attributeFilters, ActionListener<GetRuleResponse> listener) {
if (id != null) {
fetchRuleById(id, listener);
} else {
fetchAllRules(listener);
fetchAllRules(attributeFilters, listener);
}
}

Expand All @@ -93,11 +143,10 @@ private void fetchRuleById(String id, ActionListener<GetRuleResponse> listener)

private void handleGetOneRuleResponse(String id, GetResponse getResponse, ActionListener<GetRuleResponse> listener) {
if (getResponse.isExists()) {
try {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
XContentParser parser = MediaTypeRegistry.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, getResponse.getSourceAsString());
Rule rule = Rule.Builder.fromXContent(parser).build();
listener.onResponse(new GetRuleResponse(Map.of(id, rule), RestStatus.OK));
listener.onResponse(new GetRuleResponse(Map.of(id, Builder.fromXContent(parser).build()), RestStatus.OK));
} catch (IOException e) {
logger.error("Error parsing rule with ID {}: {}", id, e.getMessage());
listener.onFailure(e);
Expand All @@ -107,34 +156,59 @@ private void handleGetOneRuleResponse(String id, GetResponse getResponse, Action
}
}

private void fetchAllRules(ActionListener<GetRuleResponse> listener) {
private boolean matchesFilters(Rule rule, Map<RuleAttribute, Set<String>> attributeFilters) {
for (Map.Entry<RuleAttribute, Set<String>> entry : attributeFilters.entrySet()) {
RuleAttribute attribute = entry.getKey();
Set<String> expectedValues = entry.getValue();
Set<String> ruleValues = rule.getAttributeMap().get(attribute);
if (ruleValues == null || ruleValues.stream().noneMatch(expectedValues::contains)) {
return false;
}
}
return true;
}


private void fetchAllRules(Map<RuleAttribute, Set<String>> attributeFilters, ActionListener<GetRuleResponse> listener) {
client.prepareSearch(RULE_INDEX)
.setQuery(QueryBuilders.matchAllQuery())
.setSize(20)
.execute(ActionListener.wrap(
searchResponse -> handleGetAllRuleResponse(searchResponse, listener),
searchResponse -> handleGetAllRuleResponse(searchResponse, attributeFilters, listener),
e -> {
logger.error("Failed to fetch all rules: {}", e.getMessage());
listener.onFailure(e);
}
));
}

private void handleGetAllRuleResponse(SearchResponse searchResponse, ActionListener<GetRuleResponse> listener) {
private void handleGetAllRuleResponse(SearchResponse searchResponse, Map<RuleAttribute, Set<String>> attributeFilters, ActionListener<GetRuleResponse> listener) {
Map<String, Rule> ruleMap = Arrays.stream(searchResponse.getHits().getHits())
.map(hit -> {
try {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
XContentParser parser = MediaTypeRegistry.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, hit.getSourceAsString());
return Map.entry(hit.getId(), Rule.Builder.fromXContent(parser).build());
Rule currRule = Rule.Builder.fromXContent(parser).build();
if (matchesFilters(currRule,attributeFilters)) {
return Map.entry(hit.getId(), currRule);
}
return null;
} catch (IOException e) {
logger.error("Failed to parse rule from hit: {}", e.getMessage());
listener.onFailure(e);
return null;
}
})
.filter(Objects::nonNull)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

listener.onResponse(new GetRuleResponse(ruleMap, RestStatus.OK));
}

public Client getClient() {
return client;
}

public ClusterService getClusterService() {
return clusterService;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@

package org.opensearch.plugin.wlm;

import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.plugin.wlm.rule.service.RulePersistenceService;
import org.opensearch.plugin.wlm.rule.service.RulePersistenceServiceTests;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.Rule;
import org.opensearch.wlm.Rule.RuleAttribute;

Expand All @@ -21,6 +28,8 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.wlm.Rule._ID_STRING;
import static org.opensearch.wlm.Rule.builder;

Expand All @@ -29,20 +38,22 @@ public class RuleTestUtils {
public static final String _ID_TWO = "G5iIq84j7eK1qIAAAAIH53=1";
public static final String LABEL_ONE = "label_one";
public static final String LABEL_TWO = "label_two";
public static final String PATTERN_ONE = "pattern_1";
public static final String PATTERN_TWO = "pattern_2";
public static final String QUERY_GROUP = "query_group";
public static final String TIMESTAMP_ONE = "2024-01-26T08:58:57.558Z";
public static final String TIMESTAMP_TWO = "2023-01-26T08:58:57.558Z";
public static final Rule ruleOne = builder()
.feature(QUERY_GROUP)
.label(LABEL_ONE)
.attributeMap(Map.of(RuleAttribute.INDEX_PATTERN, Set.of("pattern_1")))
.attributeMap(Map.of(RuleAttribute.INDEX_PATTERN, Set.of(PATTERN_ONE)))
.updatedAt(TIMESTAMP_ONE)
.build();

public static final Rule ruleTwo = builder()
.feature(QUERY_GROUP)
.label(LABEL_TWO)
.attributeMap(Map.of(RuleAttribute.INDEX_PATTERN, Set.of("pattern_2", "pattern_3")))
.attributeMap(Map.of(RuleAttribute.INDEX_PATTERN, Set.of(PATTERN_TWO)))
.updatedAt(TIMESTAMP_TWO)
.build();

Expand All @@ -53,6 +64,16 @@ public static Map<String, Rule> ruleMap() {
);
}

public static RulePersistenceService setUpRulePersistenceService() {
Client client = mock(Client.class);
ClusterService clusterService = mock(ClusterService.class);
ThreadPool threadPool = mock(ThreadPool.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
return new RulePersistenceService(clusterService, client);
}

public static void assertEqualRules(
Map<String, Rule> mapOne,
Map<String, Rule> mapTwo,
Expand Down
Loading

0 comments on commit f954f61

Please sign in to comment.