From 246e8324c8531e3821d3498e15477160874fe99a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 28 Sep 2023 17:34:26 +0100 Subject: [PATCH] [ML] Handle malformed inference result (#100023) Improved logging and best attempt at handling what is expected to be a rare edge case --- .../process/PyTorchResultProcessor.java | 38 +++++++++++++++---- .../process/PyTorchResultProcessorTests.java | 17 +++++++++ 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 6b92a9349c4ea..4b925464d985b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -107,17 +107,19 @@ public void process(PyTorchProcess process) { if (result.inferenceResult() != null) { processInferenceResult(result); - } - ThreadSettings threadSettings = result.threadSettings(); - if (threadSettings != null) { - threadSettingsConsumer.accept(threadSettings); + } else if (result.threadSettings() != null) { + threadSettingsConsumer.accept(result.threadSettings()); processThreadSettings(result); - } - if (result.ackResult() != null) { + } else if (result.ackResult() != null) { processAcknowledgement(result); - } - if (result.errorResult() != null) { + } else if (result.errorResult() != null) { processErrorResult(result); + } else { + // will should only get here if the native process + // has produced a partially valid result, one that + // is accepted by the parser but does not have any + // content + handleUnknownResultType(result); } } } catch (Exception e) { @@ -208,6 +210,26 @@ void processErrorResult(PyTorchResult result) { } } + void handleUnknownResultType(PyTorchResult result) { + if (result.requestId() != null) { + PendingResult pendingResult = pendingResults.remove(result.requestId()); + if (pendingResult == null) { + logger.error(() -> format("[%s] no pending result listener for unknown result type [%s]", modelId, result)); + } else { + String msg = format("[%s] pending result listener cannot handle unknown result type [%s]", modelId, result); + logger.error(msg); + var errorResult = new ErrorResult(msg); + pendingResult.listener.onResponse(new PyTorchResult(result.requestId(), null, null, null, null, null, errorResult)); + } + } else { + // Cannot look up the listener without a request id + // all that can be done in this case is log a message. + // The result parser requires a request id so this + // code should not be hit. + logger.error(() -> format("[%s] cannot process unknown result type [%s]", modelId, result)); + } + } + public synchronized ResultStats getResultStats() { long currentMs = currentTimeMsSupplier.getAsLong(); long currentPeriodStartTimeMs = startTime + Intervals.alignToFloor(currentMs - startTime, REPORTING_PERIOD_MS); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index e172f4ffb528c..860da3140f4fe 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -153,6 +153,23 @@ public void testPendingRequestAreCalledAtShutdown() { } } + public void testsHandleUnknownResult() { + var processor = new PyTorchResultProcessor("deployment-foo", settings -> {}); + var listener = new AssertingResultListener( + r -> assertThat( + r.errorResult().error(), + containsString("[deployment-foo] pending result listener cannot handle unknown result type") + ) + ); + + processor.registerRequest("no-result-content", listener); + + processor.process( + mockNativeProcess(List.of(new PyTorchResult("no-result-content", null, null, null, null, null, null)).iterator()) + ); + assertTrue(listener.hasResponse); + } + private static class AssertingResultListener implements ActionListener { boolean hasResponse; final Consumer responseAsserter;