Skip to content

Commit

Permalink
filtering stats api for hidden model (#2307) (#2328)
Browse files Browse the repository at this point in the history
* filtering stats api for hidden model

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* reverting to excludes instead of includes

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* spotlessApply

Signed-off-by: Dhrubo Saha <[email protected]>

* renamed method name

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* spotlessApply

Signed-off-by: Dhrubo Saha <[email protected]>

* fixing a test

Signed-off-by: Dhrubo Saha <[email protected]>

* addressed comments

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit a83a78f)

Co-authored-by: Dhrubo Saha <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and dhrubo-os authored Apr 16, 2024
1 parent 9c61432 commit 3cc2fca
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,31 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.env.Environment;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionStats;
import org.opensearch.ml.stats.MLAlgoStats;
Expand All @@ -26,15 +40,26 @@
import org.opensearch.ml.stats.MLStatLevel;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.MLStatsInput;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.monitor.jvm.JvmService;
import org.opensearch.search.SearchHit;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class MLStatsNodesTransportAction extends
TransportNodesAction<MLStatsNodesRequest, MLStatsNodesResponse, MLStatsNodeRequest, MLStatsNodeResponse> {
private MLStats mlStats;
private final JvmService jvmService;

private final Client client;

private final MLModelManager mlModelManager;

/**
* Constructor
*
Expand All @@ -52,7 +77,9 @@ public MLStatsNodesTransportAction(
TransportService transportService,
ActionFilters actionFilters,
MLStats mlStats,
Environment environment
Environment environment,
Client client,
MLModelManager mlModelManager
) {
super(
MLStatsNodesAction.NAME,
Expand All @@ -67,6 +94,8 @@ public MLStatsNodesTransportAction(
);
this.mlStats = mlStats;
this.jvmService = new JvmService(environment.settings());
this.client = client;
this.mlModelManager = mlModelManager;
}

@Override
Expand Down Expand Up @@ -127,21 +156,88 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe
}

Map<String, MLModelStats> modelStats = new HashMap<>();
// return model level stats
if (mlStatsInput.includeModelStats()) {
for (String modelId : mlStats.getAllModels()) {
if (mlStatsInput.retrieveStatsForModel(modelId)) {
Map<ActionName, MLActionStats> actionStatsMap = new HashMap<>();
for (Map.Entry<ActionName, MLActionStats> entry : mlStats.getModelStats(modelId).entrySet()) {
if (mlStatsInput.retrieveStatsForAction(entry.getKey())) {
actionStatsMap.put(entry.getKey(), entry.getValue());
CountDownLatch latch = new CountDownLatch(1);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
searchHiddenModels(ActionListener.wrap(hiddenModels -> {
for (String modelId : mlStats.getAllModels()) {
if (isSuperAdmin || !hiddenModels.contains(modelId)) {
if (mlStatsInput.retrieveStatsForModel(modelId)) {
Map<ActionName, MLActionStats> actionStatsMap = new HashMap<>();
for (Map.Entry<ActionName, MLActionStats> entry : mlStats.getModelStats(modelId).entrySet()) {
if (mlStatsInput.retrieveStatsForAction(entry.getKey())) {
actionStatsMap.put(entry.getKey(), entry.getValue());
}
}
modelStats.put(modelId, new MLModelStats(actionStatsMap));
}
}
modelStats.put(modelId, new MLModelStats(actionStatsMap));
}
}, e -> { log.error("Search Hidden model wasn't successful"); }), latch);
// Wait for the asynchronous call to complete
try {
latch.await(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
// Handle interruption if necessary
Thread.currentThread().interrupt();
}
}

return new MLStatsNodeResponse(clusterService.localNode(), statValues, algorithmStats, modelStats);
}

@VisibleForTesting
void searchHiddenModels(ActionListener<Set<String>> listener, CountDownLatch latch) {
SearchRequest searchRequest = buildHiddenModelSearchRequest();
// Use a try-with-resources block to ensure resources are properly released
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
// Wrap the listener to restore thread context before calling it
ActionListener<Set<String>> internalListener = ActionListener.runAfter(listener, () -> {
latch.countDown();
threadContext.restore();
});
// Wrap the search response handler to handle success and failure cases
// Notify the listener of any search failures
ActionListener<SearchResponse> al = ActionListener.wrap(response -> {
// Initialize the result set
Set<String> result = new HashSet<>(response.getHits().getHits().length); // Set initial capacity to the number of hits

// Iterate over the search hits and add their IDs to the result set
for (SearchHit hit : response.getHits()) {
result.add(hit.getId());
}
// Notify the listener of the search results
internalListener.onResponse(result);
}, internalListener::onFailure);

// Execute the search request asynchronously
client.search(searchRequest, al);
} catch (Exception e) {
// Notify the listener of any unexpected errors
listener.onFailure(e);
}
}

private SearchRequest buildHiddenModelSearchRequest() {
SearchRequest searchRequest = new SearchRequest(CommonValue.ML_MODEL_INDEX);
// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder
.filter(
QueryBuilders
.boolQuery()
.must(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true))
// Add the additional filter to exclude documents where "chunk_number" exists
.mustNot(QueryBuilders.existsQuery("chunk_number"))
);
searchRequest.source().query(boolQueryBuilder);
// Specify the fields to include in the search results (only the "_id" field)
// No fields to exclude
searchRequest.source().fetchSource(new String[] { "_id" }, new String[] {});
return searchRequest;
}

@VisibleForTesting
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
return RestActionUtils.isSuperAdminUser(clusterService, client);
}
}
Loading

0 comments on commit 3cc2fca

Please sign in to comment.