Skip to content

Commit

Permalink
Changed error message handling
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 6, 2023
1 parent e592ad7 commit 5f47156
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -42,7 +43,8 @@
public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;
private static final String EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED = "the system encountered an unexpected error during processing";
private static final String EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED = "failed while calling model, check error log for details";
private static final String EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED = "encountered following error while calling a model";

/**
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
Expand Down Expand Up @@ -190,12 +192,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
String exceptionMessage = EXCEPTION_MESSAGE_MODEL_PROCESSING_FAILED;
if (Objects.isNull(tensor.getData())) {
if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) {
exceptionMessage = (String) tensor.getDataAsMap().get("message");
String errorFromModel = (String) tensor.getDataAsMap().get("message");
throw new IllegalStateException(
String.format(Locale.ROOT, "%s: %s", EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED, errorFromModel)
);
} else {
log.error(
"Received following output tensor from a model, there is no detailed error message: {}",
tensor.toString()
);
throw new IllegalStateException(EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED);
}
throw new IllegalStateException(exceptionMessage);
}
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

Expand Down Expand Up @@ -330,9 +331,53 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
}

public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() {
List<ModelTensors> tensorsList = new ArrayList<>();
List<ModelTensor> mlModelTensorList = List.of(
new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.")
)
);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
ModelTensorOutput outputWithErrorMessage = new ModelTensorOutput(List.of(modelTensors));

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(outputWithErrorMessage);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(any());
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);

clearInvocations(client, singleSentenceResultListener);

List<ModelTensor> mlModelTensorList2 = List.of(
new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("test_key", "test_value")
)
);
final ModelTensors modelTensors2 = new ModelTensors(mlModelTensorList2);
ModelTensorOutput outputWithErrorMessage2 = new ModelTensorOutput(List.of(modelTensors2));

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createEmptyModelTensorOutput());
actionListener.onResponse(outputWithErrorMessage2);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

Expand Down Expand Up @@ -371,22 +416,4 @@ private ModelTensorOutput createModelTensorOutput(final Map<String, String> map)
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}

private ModelTensorOutput createEmptyModelTensorOutput() {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
final ModelTensor tensor = new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.")
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}
}

0 comments on commit 5f47156

Please sign in to comment.