Skip to content

Commit

Permalink
not sending failure message when model index isn't present (#2351)
Browse files Browse the repository at this point in the history
* not sending failure message when model index isn't present

Signed-off-by: Dhrubo Saha <[email protected]>

* making profile api experience same

Signed-off-by: Dhrubo Saha <[email protected]>

* add unit test

Signed-off-by: Dhrubo Saha <[email protected]>

* applying spotless

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Apr 23, 2024
1 parent 4b26ebf commit be05dfc
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,15 @@ public void onResponse(SearchResponse searchResponse) {

@Override
public void onFailure(Exception e) {
onFailed(channel, "Searching model wasn't successful", e);
try {
builder.startObject();
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
} catch (IOException ex) {
String errorMessage = "Failed to get ML node level profile";
log.error(errorMessage, e);
onFailed(channel, errorMessage, e);
}
}

}, threadContext::restore));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ public void onResponse(SearchResponse searchResponse) {

@Override
public void onFailure(Exception e) {
onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Searching model wasn't successful", e);
try {
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
} catch (IOException ex) {
onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to retrieve Cluster level metrics", e);
}
}
}, threadContext::restore));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
Expand Down Expand Up @@ -51,6 +52,7 @@
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.action.profile.MLProfileAction;
Expand All @@ -67,6 +69,7 @@
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.profile.MLPredictRequestStats;
import org.opensearch.ml.profile.MLProfileInput;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -151,13 +154,6 @@ public void setup() throws IOException {
testState = setupTestClusterState();
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
Map<String, MLTask> nodeTasks = new HashMap<>();
Expand Down Expand Up @@ -207,6 +203,13 @@ public void testRoutes() {
}

public void test_PrepareRequest_TaskRequest() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

RestRequest request = getRestRequest();
profileAction.handleRequest(request, channel, client);

Expand All @@ -218,6 +221,13 @@ public void test_PrepareRequest_TaskRequest() throws Exception {
}

public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/tasks").build();
profileAction.handleRequest(request, channel, client);

Expand All @@ -228,6 +238,13 @@ public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception {
}

public void test_PrepareRequest_ModelRequest() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

RestRequest request = getModelRestRequest();
profileAction.handleRequest(request, channel, client);

Expand All @@ -239,6 +256,13 @@ public void test_PrepareRequest_ModelRequest() throws Exception {
}

public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/models").build();
profileAction.handleRequest(request, channel, client);

Expand All @@ -249,6 +273,12 @@ public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception {
}

public void test_PrepareRequest_EmptyNodeProfile() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());
doAnswer(invocation -> {
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
MLProfileResponse profileResponse = new MLProfileResponse(clusterName, new ArrayList<>(), new ArrayList<>());
Expand All @@ -267,6 +297,13 @@ public void test_PrepareRequest_EmptyNodeProfile() throws Exception {
}

public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
Map<String, MLTask> nodeTasks = new HashMap<>();
Expand All @@ -288,6 +325,13 @@ public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception {
}

public void test_PrepareRequest_WithRequestContent() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

MLProfileInput mlProfileInput = new MLProfileInput();
RestRequest request = getProfileRestRequest(mlProfileInput);
profileAction.handleRequest(request, channel, client);
Expand All @@ -296,6 +340,13 @@ public void test_PrepareRequest_WithRequestContent() throws Exception {
}

public void test_PrepareRequest_Failure() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("test failure"));
Expand All @@ -308,14 +359,84 @@ public void test_PrepareRequest_Failure() throws Exception {
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
}

public void test_Search_Failure() throws Exception {
// Setup to simulate a search failure
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Mocking Exception")); // Trigger failure
return null;
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));

// Create a RestRequest instance for testing
RestRequest request = getRestRequest(); // Ensure this method correctly initializes a RestRequest

// Handle the request with the expectation of handling a failure
profileAction.handleRequest(request, channel, client);

// Verification that the search method was called exactly once
verify(client, times(1)).search(any(SearchRequest.class), any(ActionListener.class));

// Capturing the response sent to the channel
ArgumentCaptor<BytesRestResponse> responseCaptor = ArgumentCaptor.forClass(BytesRestResponse.class);
verify(channel).sendResponse(responseCaptor.capture());

// Check the response status code to see if it correctly reflects the error
BytesRestResponse response = responseCaptor.getValue();
assertEquals(RestStatus.OK, response.status());
assertTrue(response.content().utf8ToString().contains("{}"));
}

public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());
MLProfileInput mlProfileInput = new MLProfileInput();
RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model"));
profileAction.handleRequest(request, channel, client);
ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
}

// public void testNodeViewOutput() throws Exception {
// // Assuming setup for non-empty node responses as done in the initial setup
// MLProfileInput mlProfileInput = new MLProfileInput();
// RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "node"));
// profileAction.handleRequest(request, channel, client);
//
// ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
// verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
//
// // Verify that the response is correctly formed for the node view
// verify(channel).sendResponse(argThat(response -> {
// // Ensure the response content matches expected node view structure
// String content = response.content().utf8ToString();
// return content.contains("\"node\":") && !content.contains("\"models\":");
// }));
// }

public void testBackendFailureHandling() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
listener.onResponse(response);
return null;
}).when(client).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
ActionListener<MLProfileResponse> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException("Simulated backend failure"));
return null;
}).when(client).execute(eq(MLProfileAction.INSTANCE), any(MLProfileRequest.class), any(ActionListener.class));

RestRequest request = getRestRequest();
profileAction.handleRequest(request, channel, client);

verify(channel).sendResponse(argThat(response -> response.status() == RestStatus.INTERNAL_SERVER_ERROR));
}

private SearchResponse createSearchModelResponse() throws IOException {
XContentBuilder content = builder();
content.startObject();
Expand Down

0 comments on commit be05dfc

Please sign in to comment.