diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index b76a39a0f2ac2..86bf4a925121a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -394,6 +394,16 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); var responses = entry.getValue(); + if (responses == null) { + if (item.request() instanceof UpdateRequest == false) { + // could be an assert + throw new IllegalArgumentException( + "Inference results can only be cleared for update requests where a field is explicitly set to null." + ); + } + inferenceFieldsMap.put(fieldName, null); + continue; + } var model = responses.get(0).model(); // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); @@ -480,6 +490,7 @@ private Map> createFieldInferenceRequests(Bu } final Map docMap = indexRequest.sourceAsMap(); + Object explicitNull = new Object(); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); @@ -487,10 +498,11 @@ private Map> createFieldInferenceRequests(Bu if (useInferenceMetadataFieldsFormat) { var inferenceMetadataFieldsValue = XContentMapValues.extractValue( InferenceMetadataFieldsMapper.NAME + "." + field, - docMap + docMap, + explicitNull ); if (inferenceMetadataFieldsValue != null) { - // Inference has already been computed + // Inference has already been computed for this source field continue; } } else { @@ -503,9 +515,20 @@ private Map> createFieldInferenceRequests(Bu int order = 0; for (var sourceField : entry.getSourceFields()) { - // TODO: Detect when the field is provided with an explicit null value - var valueObj = XContentMapValues.extractValue(sourceField, docMap); - if (valueObj == null) { + var valueObj = XContentMapValues.extractValue(sourceField, docMap, explicitNull); + if (useInferenceMetadataFieldsFormat && isUpdateRequest && valueObj == explicitNull) { + /** + * It's an update request, and the source field is explicitly set to null, + * so we need to propagate this information to the inference fields metadata + * to overwrite any inference previously computed on the field. + * This ensures that the field is treated as intentionally cleared, + * preventing any unintended carryover of prior inference results. + */ + var slot = ensureResponseAccumulatorSlot(itemIndex); + slot.responses.put(sourceField, null); + continue; + } + if (valueObj == null || valueObj == explicitNull) { if (isUpdateRequest && (useInferenceMetadataFieldsFormat == false)) { addInferenceResponseFailure( item.id(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 73fe792664071..d0e5b428f3422 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -212,6 +213,11 @@ public void testItemFailures() throws Exception { ), equalTo("I am a success") ); + if (useInferenceMetadataFieldsFormat) { + assertNotNull( + XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME + ".field1", actualRequest.sourceAsMap()) + ); + } // item 2 is a failure assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); @@ -239,6 +245,85 @@ public void testItemFailures() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testExplicitNull() throws Exception { + StaticModel model = StaticModel.createRandomInstance(); + + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10), + IndexVersion.current() + ); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(4)); + + Object explicitNull = new Object(); + // item 0 + assertNull(bulkShardRequest.items()[0].getPrimaryResponse()); + IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request()); + assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull)); + + // item 1 is a success + assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); + actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); + assertThat(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap()), equalTo("I am a success")); + assertNotNull( + XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + ".field1", + actualRequest.sourceAsMap(), + explicitNull + ) + ); + + // item 2 is a failure + assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); + var failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + + // item 3 + assertNull(bulkShardRequest.items()[3].getPrimaryResponse()); + actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request()); + assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertTrue( + XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + ".field1", + actualRequest.sourceAsMap(), + explicitNull + ) == explicitNull + ); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[4]; + Map sourceWithNull = new HashMap<>(); + sourceWithNull.put("field1", null); + items[0] = new BulkItemRequest(0, new IndexRequest("index").source(sourceWithNull)); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); + items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(sourceWithNull))); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { IndexVersion indexVersion = getRandomIndexVersion();